diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..f6cb8ad931f5442a5e2276ed66ba7b5a733b82c6 --- /dev/null +++ b/.clang-format @@ -0,0 +1 @@ +BasedOnStyle: Google diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000000000000000000000000000000000..229856aa4366c092e94eb573dbbd67756974a3ae --- /dev/null +++ b/.flake8 @@ -0,0 +1,22 @@ +[flake8] +ignore = + ;W503 line break before binary operator + W503, + ;E203 whitespace before ':' + E203, + +; exclude file +exclude = + .tox, + .git, + __pycache__, + build, + dist, + *.pyc, + *.egg-info, + .cache, + .eggs + +max-line-length = 120 + +per-file-ignores = __init__.py:F401 diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml new file mode 100644 index 0000000000000000000000000000000000000000..c09c1030843ba057b64a093b31ce6f70df323f68 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -0,0 +1,38 @@ +name: ๐Ÿ› Bug Report +description: Create a report to help us reproduce and fix the bug +title: "[BUG]: " +labels: [bug] + +body: +- type: markdown + attributes: + value: > + #### Not suitable for your needs? [Open a blank issue](https://github.com/hpcaitech/ColossalAI/issues/new). +- type: textarea + attributes: + label: ๐Ÿ› Describe the bug + description: | + **Describe the bug** + A clear and concise description of what the bug is. + **To Reproduce** + Steps or code snippet to reproduce the behavior. + **Expected behavior** + A clear and concise description of what you expected to happen. + **Screenshots** + If applicable, add screenshots to help explain your problem. + **Optional: Affiliation** + Institution/email information helps better analyze and evaluate users to improve the project. Welcome to establish in-depth cooperation. + placeholder: | + A clear and concise description of what the bug is. + validations: + required: true +- type: textarea + attributes: + label: Environment + description: | + Please provide the environment information, eg. CUDA/cuDNN/NCCL/Python/PyTorch version. + +- type: markdown + attributes: + value: > + Thanks for contributing ๐ŸŽ‰! diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..673b1274c94b147537e16385a208ec31705c5acb --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,11 @@ +blank_issues_enabled: true +contact_links: + - name: โ“ Simple question - Slack Chat + url: https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w + about: This issue tracker is not for technical support. Please use our Slack chat, and ask the community for help. + - name: โ“ Simple question - WeChat + url: https://github.com/hpcaitech/ColossalAI/blob/main/docs/images/WeChat.png + about: This issue tracker is not for technical support. Please use WeChat, and ask the community for help. + - name: ๐Ÿ˜Š Advanced question - GitHub Discussions + url: https://github.com/hpcaitech/ColossalAI/discussions + about: Use GitHub Discussions for advanced and unanswered technical questions, requiring a maintainer's answer. \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/documentation.yml b/.github/ISSUE_TEMPLATE/documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..511997e2ee813d2bf84bf6fd7424ad580fc01bc3 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/documentation.yml @@ -0,0 +1,29 @@ +name: ๐Ÿ“š Documentation +description: Report an issue related to https://www.colossalai.org/ +title: "[DOC]: " +labels: [documentation] + +body: +- type: markdown + attributes: + value: > + #### Not suitable for your needs? [Open a blank issue](https://github.com/hpcaitech/ColossalAI/issues/new). +- type: textarea + attributes: + label: ๐Ÿ“š The doc issue + description: | + **Description** What content in [Documentation](https://www.colossalai.org/) is an issue? + **Location** Where is the issue location? + **Expectation** What is your expected content about it? + **Screenshots** If applicable, add screenshots to help explain your problem. + **Suggestions** Tell us how we could improve the documentation. + **Optional: Affiliation** Institution/email information helps better analyze and evaluate users to improve the project. Welcome to establish in-depth cooperation. + placeholder: | + A clear and concise description of the issue. + validations: + required: true + +- type: markdown + attributes: + value: > + Thanks for contributing ๐ŸŽ‰! diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000000000000000000000000000000000000..d05bc25f6f4161b2df404d5bfff3c59d10c9fc36 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,35 @@ +name: ๐Ÿš€ Feature request +description: Suggest an idea for this project +title: "[FEATURE]: " +labels: [enhancement] + +body: +- type: markdown + attributes: + value: > + #### Not suitable for your needs? [Open a blank issue](https://github.com/hpcaitech/ColossalAI/issues/new). +- type: textarea + attributes: + label: Describe the feature + description: | + **Is your feature request related to a problem? Please describe.** + A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + **Describe the solution you'd like** + A clear and concise description of what you want to happen. + **Describe alternatives you've considered** + A clear and concise description of any alternative solutions or features you've considered. + **Screenshots** + If applicable, add screenshots to help explain your problem. + **Suggest a potential alternative/fix** + Tell us how we could improve this project. + **Optional: Affiliation** + Institution/email information helps better analyze and evaluate users to improve the project. Welcome to establish in-depth cooperation. + placeholder: | + A clear and concise description of your idea. + validations: + required: true + +- type: markdown + attributes: + value: > + Thanks for contributing ๐ŸŽ‰! diff --git a/.github/ISSUE_TEMPLATE/proposal.yml b/.github/ISSUE_TEMPLATE/proposal.yml new file mode 100644 index 0000000000000000000000000000000000000000..614ef77751ad88f6dd55856b89fe91ffd1d294b8 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/proposal.yml @@ -0,0 +1,47 @@ +name: ๐Ÿ’ฅ Proposal +description: Propose a non-trivial change to Colossal-AI +title: "[PROPOSAL]: " +labels: [enhancement] + +body: +- type: markdown + attributes: + value: | + Common reasons for proposals include: + + - Altering the infrastructure; + - Bumping a critical dependency's major version; + - A significant improvement in user-friendliness; + - Significant refactor; + - Optional: Affiliation/email information helps better analyze and evaluate users to improve the project. Welcome to establish in-depth cooperation. + - ... + + Please note this is not for feature request or bug template; such action could make us identify the issue wrongly and close it without doing anything. + + We give you maximum freedom to write an elaborated proposal illustrating why you think the change is beneficial for us, and what steps we should take to turn this into reality. + + +- type: textarea + attributes: + label: Proposal + description: A clear and concise description of what the proposal is. + validations: + required: true + +- type: checkboxes + attributes: + label: Self-service + description: | + If you feel like you could contribute to this issue, please check the box below. This would tell us and other people looking for contributions that someone's working on it. + If you do check this box, please send a pull request within 7 days after a maintainer's approval so we can still delegate this to someone else. + + Proposals usually involve significant code changes, so please reach consensus with the maintainers before rushing to implement it, and make sure you follow the [Contributing Guidelines](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md). + This ensures that you don't waste your time and we don't waste ours reading the large diffs. + options: + - label: I'd be willing to do some initial work on this proposal myself. + + +- type: markdown + attributes: + value: > + Thanks for contributing ๐ŸŽ‰! diff --git a/.github/reviewer_list.yml b/.github/reviewer_list.yml new file mode 100644 index 0000000000000000000000000000000000000000..ce1d4849f9d27c48434481dfc32fe2e5617fa85b --- /dev/null +++ b/.github/reviewer_list.yml @@ -0,0 +1,9 @@ +addReviewers: true + +addAssignees: author + +numberOfReviewers: 1 + +reviewers: + - frankleeeee + - kurisusnowdeng diff --git a/.github/workflows/assign_reviewer.yml b/.github/workflows/assign_reviewer.yml new file mode 100644 index 0000000000000000000000000000000000000000..6ebb3398265e55c2cf51a03ce03abb2bea86e2da --- /dev/null +++ b/.github/workflows/assign_reviewer.yml @@ -0,0 +1,18 @@ +name: Assign Reviewers for Team + +on: + pull_request: + types: [opened] + +jobs: + assign_reviewer: + name: Assign Reviewer for PR + runs-on: ubuntu-latest + if: | + github.event.pull_request.draft == false && github.base_ref == 'main' + && github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' + && toJson(github.event.pull_request.requested_reviewers) == '[]' + steps: + - uses: kentaro-m/auto-assign-action@v1.2.1 + with: + configuration-path: '.github/reviewer_list.yml' diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000000000000000000000000000000000000..36e33b0ab59b1c788868a724495ae94e0f3db6a3 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,48 @@ +name: Build + +on: + pull_request: + types: [synchronize, labeled] + +jobs: + build: + name: Build and Test Colossal-AI + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && + contains( github.event.pull_request.labels.*.name, 'Run Build and Test') + runs-on: [self-hosted, gpu] + container: + image: hpcaitech/pytorch-cuda:1.11.0-11.3.0 + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + timeout-minutes: 40 + steps: + - uses: actions/checkout@v2 + with: + repository: hpcaitech/TensorNVMe + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + path: TensorNVMe + - name: Install tensornvme + run: | + cd TensorNVMe + conda install cmake + pip install -r requirements.txt + pip install -v . + - uses: actions/checkout@v2 + with: + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + - name: Install Colossal-AI + run: | + [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ + pip install -r requirements/requirements.txt + pip install -v -e . + cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ + pip install -r requirements/requirements-test.txt + - name: Unit Testing + run: | + PYTHONPATH=$PWD pytest tests + env: + DATA: /data/scratch/cifar-10 + NCCL_SHM_DISABLE: 1 + LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 diff --git a/.github/workflows/build_gpu_8.yml b/.github/workflows/build_gpu_8.yml new file mode 100644 index 0000000000000000000000000000000000000000..2a405d86f1dc8514c34c3391702fc6f3ca540bbb --- /dev/null +++ b/.github/workflows/build_gpu_8.yml @@ -0,0 +1,46 @@ +name: Build on 8 GPUs + +on: + schedule: + # run at 00:00 of every Sunday + - cron: '0 0 * * *' + workflow_dispatch: + +jobs: + build: + name: Build and Test Colossal-AI + if: github.repository == 'hpcaitech/ColossalAI' + runs-on: [self-hosted, 8-gpu] + container: + image: hpcaitech/pytorch-cuda:1.11.0-11.3.0 + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + timeout-minutes: 40 + steps: + - uses: actions/checkout@v2 + with: + repository: hpcaitech/TensorNVMe + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + path: TensorNVMe + - name: Install tensornvme + run: | + cd TensorNVMe + conda install cmake + pip install -r requirements.txt + pip install -v . + - uses: actions/checkout@v2 + with: + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + - name: Install Colossal-AI + run: | + [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ + pip install -r requirements/requirements.txt + pip install -v -e . + cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ + pip install -r requirements/requirements-test.txt + - name: Unit Testing + run: | + gpu_used=$(nvidia-smi -i 0 --query-gpu=memory.used --format=csv,noheader,nounits) + [ "$gpu_used" -le "100" ] && PYTHONPATH=$PWD pytest tests + env: + DATA: /data/scratch/cifar-10 + LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 diff --git a/.github/workflows/close_inactive.yml b/.github/workflows/close_inactive.yml new file mode 100644 index 0000000000000000000000000000000000000000..e7dec44309303f4602cab8b3205f25cabb07cf30 --- /dev/null +++ b/.github/workflows/close_inactive.yml @@ -0,0 +1,26 @@ +name: Close inactive issues + +on: + schedule: + - cron: "0 0 * * *" + +jobs: + close-issues: + if: github.event.pull_request.draft == false && github.base_ref == 'main' && github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - uses: actions/stale@v3 + with: + days-before-issue-stale: 14 + days-before-issue-close: -1 + stale-issue-label: "stale" + stale-issue-message: "This issue is stale because it has been open for 14 days with no activity." +# close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." + days-before-pr-stale: 14 + days-before-pr-close: -1 + stale-pr-message: "This PR is stale because it has been open for 14 days with no activity." +# close-pr-message: "This PR was closed because it has been inactive for 14 days since being marked as stale." + repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/compatibility_test.yml b/.github/workflows/compatibility_test.yml new file mode 100644 index 0000000000000000000000000000000000000000..eadd07886106f039153fc96388b46a897f423694 --- /dev/null +++ b/.github/workflows/compatibility_test.yml @@ -0,0 +1,84 @@ +name: Compatibility Test + +on: + workflow_dispatch: + inputs: + torch_version: + type: string + description: torch version, separated by comma + required: true + cuda_version: + type: string + description: cuda version, separated by comma + required: true + +jobs: + matrix_preparation: + name: Prepare Container List + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - id: set-matrix + env: + TORCH_VERSIONS: ${{ inputs.torch_version }} + CUDA_VERSIONS: ${{ inputs.cuda_version }} + run: | + IFS=',' + DOCKER_IMAGE=() + + for tv in $TORCH_VERSIONS + do + for cv in $CUDA_VERSIONS + do + DOCKER_IMAGE+=("\"hpcaitech/pytorch-cuda:${tv}-${cv}\"") + done + done + + container=$( IFS=',' ; echo "${DOCKER_IMAGE[*]}" ) + container="[${container}]" + echo "$container" + echo "::set-output name=matrix::{\"container\":$(echo "$container")}" + + build: + name: Test for PyTorch Compatibility + needs: matrix_preparation + if: github.repository == 'hpcaitech/ColossalAI' + runs-on: [self-hosted, gpu] + strategy: + fail-fast: false + matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} + container: + image: ${{ matrix.container }} + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + timeout-minutes: 120 + steps: + - name: Install dependencies + run: | + pip install -U pip setuptools wheel --user + - uses: actions/checkout@v2 + with: + repository: hpcaitech/TensorNVMe + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + path: TensorNVMe + - name: Install tensornvme + run: | + cd TensorNVMe + conda install cmake + pip install -r requirements.txt + pip install -v . + - uses: actions/checkout@v2 + with: + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + - name: Install Colossal-AI + run: | + pip install -r requirements/requirements.txt + pip install -v --no-cache-dir . + pip install -r requirements/requirements-test.txt + - name: Unit Testing + run: | + PYTHONPATH=$PWD pytest tests + env: + DATA: /data/scratch/cifar-10 + NCCL_SHM_DISABLE: 1 + LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 diff --git a/.github/workflows/draft_github_release_post.yml b/.github/workflows/draft_github_release_post.yml new file mode 100644 index 0000000000000000000000000000000000000000..413714dafa8646d0302274de58917766ac0c6833 --- /dev/null +++ b/.github/workflows/draft_github_release_post.yml @@ -0,0 +1,44 @@ +name: Draft GitHub Release Post + +on: + workflow_dispatch: + pull_request: + paths: + - 'version.txt' + types: + - closed + + +jobs: + release: + name: Draft Release Post + if: github.repository == 'hpcaitech/ColossalAI' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - uses: actions/setup-python@v2 + with: + python-version: '3.8.14' + - name: generate draft + id: generate_draft + run: | + version=v$(cat version.txt) + pip install requests + python ./.github/workflows/scripts/generate_release_draft.py --out $PWD/release_draft.md --version $version + echo "::set-output name=version::$version" + echo "::set-output name=path::$PWD/release_draft.md" + env: + GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Create Release + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ steps.generate_draft.outputs.version }} + release_name: Version ${{ steps.generate_draft.outputs.version }} Release Today! + body_path: ${{ steps.generate_draft.outputs.path }} + draft: True + prerelease: false diff --git a/.github/workflows/release_bdist.yml b/.github/workflows/release_bdist.yml new file mode 100644 index 0000000000000000000000000000000000000000..c9c51df8d0747656772aa9fa35cca347ec72aeb0 --- /dev/null +++ b/.github/workflows/release_bdist.yml @@ -0,0 +1,99 @@ +name: Release bdist wheel + +on: + workflow_dispatch: + inputs: + torch_version: + type: string + description: torch version, separated by comma + required: true + default: "all" + cuda_version: + type: string + description: cuda version, separated by comma + required: true + github_ref: + type: string + description: Branch or Tag + default: 'main' + required: true + +jobs: + matrix_preparation: + name: Prepare Container List + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - id: set-matrix + env: + TORCH_VERSIONS: ${{ inputs.torch_version }} + CUDA_VERSIONS: ${{ inputs.cuda_version }} + run: | + echo $TORCH_VERSIONS + echo $CUDA_VERSIONS + IFS=',' + DOCKER_IMAGE=() + + for cv in $CUDA_VERSIONS + do + DOCKER_IMAGE+=("\"hpcaitech/cuda-conda:${cv}\"") + done + + container=$( IFS=',' ; echo "${DOCKER_IMAGE[*]}" ) + container="[${container}]" + echo "$container" + echo "::set-output name=matrix::{\"container\":$(echo "$container")}" + + build: + name: Release bdist wheels + needs: matrix_preparation + if: github.repository == 'hpcaitech/ColossalAI' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor) + runs-on: [self-hosted, gpu] + strategy: + fail-fast: false + matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} + container: + image: ${{ matrix.container }} + options: --gpus all --rm + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + # cub is for cuda 10.2 + - name: Copy scripts and checkout + run: | + cp -r ./.github/workflows/scripts/* ./ + + # link the cache diretories to current path + ln -s /github/home/conda_pkgs ./conda_pkgs + ln -s /github/home/pip_wheels ./pip_wheels + + # set the conda package path + echo "pkgs_dirs:\n - $PWD/conda_pkgs" > ~/.condarc + + # set safe directory + git config --global --add safe.directory /__w/ColossalAI/ColossalAI + + # check out + git checkout $git_ref + + # get cub package for cuda 10.2 + wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip + unzip 1.8.0.zip + env: + git_ref: ${{ github.event.inputs.github_ref }} + - name: Build bdist wheel + run: | + pip install beautifulsoup4 requests packaging + python ./build_colossalai_wheel.py --torch_version $TORCH_VERSIONS + env: + TORCH_VERSIONS: ${{ inputs.torch_version }} + - name: ๐Ÿš€ Deploy + uses: garygrossgarten/github-action-scp@release + with: + local: all_dist + remote: ${{ secrets.PRIVATE_PYPI_DIR }} + host: ${{ secrets.PRIVATE_PYPI_HOST }} + username: ${{ secrets.PRIVATE_PYPI_USER }} + password: ${{ secrets.PRIVATE_PYPI_PASSWD }} diff --git a/.github/workflows/release_docker.yml b/.github/workflows/release_docker.yml new file mode 100644 index 0000000000000000000000000000000000000000..328d232a835657f6510113fe885208cf49bada0c --- /dev/null +++ b/.github/workflows/release_docker.yml @@ -0,0 +1,40 @@ +name: Publish Docker Image to DockerHub + +on: + workflow_dispatch: + release: + types: [published] + +jobs: + release: + name: Publish Docker Image to DockerHub + if: github.repository == 'hpcaitech/ColossalAI' + runs-on: [self-hosted, gpu] + container: + image: "hpcaitech/docker-in-docker:latest" + options: --gpus all --rm -v /var/run/docker.sock:/var/run/docker.sock + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Build Docker + run: | + version=$(cat version.txt) + docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t hpcaitech/colossalai:$version ./docker + - name: Log in to Docker Hub + uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + - name: Extract metadata (tags, labels) for Docker + id: meta + uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38 + with: + images: hpcaitech/colossalai + - name: Build and push Docker image + uses: docker/build-push-action@ad44023a93711e3deb337508980b4b5e9bcdc5dc + with: + context: . + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} diff --git a/.github/workflows/release_nightly.yml b/.github/workflows/release_nightly.yml new file mode 100644 index 0000000000000000000000000000000000000000..6bc000d1f4f6902f2dde18cf688b82cd3874d05e --- /dev/null +++ b/.github/workflows/release_nightly.yml @@ -0,0 +1,73 @@ +name: Release bdist wheel for Nightly versions + +on: + schedule: + # run at 00:00 of every Sunday + - cron: '0 0 * * 6' + workflow_dispatch: + +jobs: + matrix_preparation: + name: Prepare Container List + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - id: set-matrix + run: | + matrix="[\"hpcaitech/cuda-conda:11.3\", \"hpcaitech/cuda-conda:10.2\"]" + echo $matrix + echo "::set-output name=matrix::{\"container\":$(echo $matrix)}" + + build: + name: Release bdist wheels + needs: matrix_preparation + if: github.repository == 'hpcaitech/ColossalAI' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor) + runs-on: [self-hosted, gpu] + strategy: + fail-fast: false + matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} + container: + image: ${{ matrix.container }} + options: --gpus all --rm + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + # cub is for cuda 10.2 + - name: Copy scripts and checkout + run: | + cp -r ./.github/workflows/scripts/* ./ + ln -s /github/home/pip_wheels ./pip_wheels + wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip + unzip 1.8.0.zip + - name: Build bdist wheel + run: | + pip install beautifulsoup4 requests packaging + python ./build_colossalai_wheel.py --nightly + - name: ๐Ÿš€ Deploy + uses: garygrossgarten/github-action-scp@release + with: + local: all_dist + remote: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }} + host: ${{ secrets.PRIVATE_PYPI_HOST }} + username: ${{ secrets.PRIVATE_PYPI_USER }} + password: ${{ secrets.PRIVATE_PYPI_PASSWD }} + remove_old_build: + name: Remove old nightly build + runs-on: ubuntu-latest + needs: build + steps: + - name: executing remote ssh commands using password + uses: appleboy/ssh-action@master + env: + BUILD_DIR: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }} + with: + host: ${{ secrets.PRIVATE_PYPI_HOST }} + username: ${{ secrets.PRIVATE_PYPI_USER }} + password: ${{ secrets.PRIVATE_PYPI_PASSWD }} + envs: BUILD_DIR + script: | + cd $BUILD_DIR + find . -type f -mtime +0 -exec rm -f {} + + script_stop: true diff --git a/.github/workflows/scripts/build_colossalai_wheel.py b/.github/workflows/scripts/build_colossalai_wheel.py new file mode 100644 index 0000000000000000000000000000000000000000..a9ac16fbc94ad5bd2d89e84e9944cafe6b6d28fb --- /dev/null +++ b/.github/workflows/scripts/build_colossalai_wheel.py @@ -0,0 +1,119 @@ +import argparse +import os +import subprocess +from filecmp import cmp +from functools import cmp_to_key + +import requests +from bs4 import BeautifulSoup +from packaging import version + +WHEEL_TEXT_ROOT_URL = 'https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels' +RAW_TEXT_FILE_PREFIX = 'https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/torch_build/torch_wheels' +CUDA_HOME = os.environ['CUDA_HOME'] + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--torch_version', type=str) + parser.add_argument( + '--nightly', + action='store_true', + help= + 'whether this build is for nightly release, if True, will only build on the latest PyTorch version and Python 3.8' + ) + return parser.parse_args() + + +def get_cuda_bare_metal_version(): + raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return bare_metal_major, bare_metal_minor + + +def all_wheel_info(): + page_text = requests.get(WHEEL_TEXT_ROOT_URL).text + soup = BeautifulSoup(page_text) + + all_a_links = soup.find_all('a') + + wheel_info = dict() + + for a_link in all_a_links: + if 'cuda' in a_link.text and '.txt' in a_link.text: + filename = a_link.text + torch_version, cuda_version = filename.rstrip('.txt').split('-') + cuda_version = cuda_version.lstrip('cuda') + + if torch_version not in wheel_info: + wheel_info[torch_version] = dict() + wheel_info[torch_version][cuda_version] = dict() + + file_text = requests.get(f'{RAW_TEXT_FILE_PREFIX}/{filename}').text + lines = file_text.strip().split('\n') + + for line in lines: + parts = line.split('\t') + method, url, python_version = parts[:3] + + if len(parts) > 3: + flags = parts[3] + flags = ' '.join(flags.split('+')) + else: + flags = '' + wheel_info[torch_version][cuda_version][python_version] = dict(method=method, url=url, flags=flags) + return wheel_info + + +def build_colossalai(wheel_info): + cuda_version_major, cuda_version_minor = get_cuda_bare_metal_version() + cuda_version_on_host = f'{cuda_version_major}.{cuda_version_minor}' + + for torch_version, cuda_versioned_wheel_info in wheel_info.items(): + for cuda_version, python_versioned_wheel_info in cuda_versioned_wheel_info.items(): + if cuda_version_on_host == cuda_version: + for python_version, wheel_info in python_versioned_wheel_info.items(): + url = wheel_info['url'] + method = wheel_info['method'] + flags = wheel_info['flags'] + filename = url.split('/')[-1].replace('%2B', '+') + cmd = f'bash ./build_colossalai_wheel.sh {method} {url} {filename} {cuda_version} {python_version} {torch_version} {flags}' + os.system(cmd) + + +def main(): + args = parse_args() + wheel_info = all_wheel_info() + + # filter wheels on condition + all_torch_versions = list(wheel_info.keys()) + + def _compare_version(a, b): + if version.parse(a) > version.parse(b): + return 1 + else: + return -1 + + all_torch_versions.sort(key=cmp_to_key(_compare_version)) + + if args.nightly: + # only keep the latest version + for key in all_torch_versions[:-1]: + wheel_info.pop(key) + elif args.torch_version != 'all': + torch_versions = args.torch_version.split(',') + # only keep the torch versions specified + for key in all_torch_versions: + if key not in torch_versions: + wheel_info.pop(key) + + build_colossalai(wheel_info) + + +if __name__ == '__main__': + main() diff --git a/.github/workflows/scripts/build_colossalai_wheel.sh b/.github/workflows/scripts/build_colossalai_wheel.sh new file mode 100644 index 0000000000000000000000000000000000000000..c0d40fd2cc99291d166425dcbfbd72a1e644e8bd --- /dev/null +++ b/.github/workflows/scripts/build_colossalai_wheel.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +method=${1} +url=${2} +filename=${3} +cuda_version=${4} +python_version=${5} +torch_version=${6} +flags=${@:7} + +git reset --hard HEAD +mkdir -p ./all_dist +source activate base +conda create -n $python_version -y python=$python_version +source activate $python_version + +if [ $1 == "pip" ] +then + wget -nc -q -O ./pip_wheels/$filename $url + pip install ./pip_wheels/$filename + +elif [ $1 == 'conda' ] +then + conda install pytorch==$torch_version cudatoolkit=$cuda_version $flags +else + echo Invalid installation method + exit +fi + +if [ $cuda_version == "10.2" ] +then + cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ +fi + +python setup.py bdist_wheel +mv ./dist/* ./all_dist +# must remove build to enable compilation for +# cuda extension in the next build +rm -rf ./build +python setup.py clean +conda deactivate +conda env remove -n $python_version diff --git a/.github/workflows/scripts/generate_release_draft.py b/.github/workflows/scripts/generate_release_draft.py new file mode 100644 index 0000000000000000000000000000000000000000..1c407cf14554db7862ed11f260f020db1990b2fb --- /dev/null +++ b/.github/workflows/scripts/generate_release_draft.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python +# coding: utf-8 + +import argparse +import os +import re + +import requests + +COMMIT_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/commits' +TAGS_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/tags' + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--out', type=str, help='output path for the release draft', required=True) + parser.add_argument('--version', type=str, help='current version to release', required=True) + return parser.parse_args() + + +def get_latest_tag_commit(headers=None): + res = requests.get(url=TAGS_API, headers=headers) + data = res.json() + commit_hash = data[0]['commit']['sha'] + version = data[0]['name'] + return commit_hash, version + + +def get_commit_info(commit_hash, headers=None): + api = f'{COMMIT_API}/{commit_hash}' + res = requests.get(url=api, headers=headers) + return res.json() + + +def get_all_commit_info(since, headers=None): + page = 1 + results = [] + + while True: + api = f'{COMMIT_API}?since={since}&per_page=100&page={page}' + resp = requests.get(url=api, headers=headers) + data = resp.json() + + # exit when no more data + if len(data) == 0: + break + + results.extend(data) + page += 1 + + return results + + +def collate_release_info(commit_info_list): + results = dict() + pattern = pattern = r'\[.*\]' + + for commit_info in commit_info_list: + author = commit_info['commit']['author']['name'] + author_url = commit_info['author']['url'] + msg = commit_info['commit']['message'] + match = re.search(pattern, msg) + + if match: + tag = match.group().lstrip('[').rstrip(']').capitalize() + if tag not in results: + results[tag] = [] + results[tag].append((msg, author, author_url)) + + return results + + +def generate_release_post_markdown(current_version, last_version, release_info): + text = [] + + # add highlights + highlights = "## What's Changed \n\n" + text.append(highlights) + + # add items + for k, v in release_info.items(): + topic = f"### {k} \n" + text.append(topic) + + for msg, author, author_url in v: + # only keep the first line + msg = msg.split('\n')[0] + + item = f'{msg} by [{author}]({author_url})\n' + text.append(f'- {item}') + + text.append('\n') + + # add full change log + text.append( + f'**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}') + + return text + + +if __name__ == '__main__': + args = parse_args() + token = os.environ['GITHUB_API_TOKEN'] + headers = {'Authorization': token} + + # get previous release tag + last_release_commit, last_version = get_latest_tag_commit(headers) + last_release_commit_info = get_commit_info(last_release_commit, headers=headers) + last_release_date = last_release_commit_info['commit']['author']['date'] + + # get the commits since last release + commit_info = get_all_commit_info(since=last_release_date, headers=headers) + commit_info = commit_info[:-1] # remove the release commit + + # collate into markdown + release_info = collate_release_info(commit_info) + markdown_text = generate_release_post_markdown(args.version, last_version, release_info) + + # write into a file + with open(args.out, 'w') as f: + for line in markdown_text: + f.write(line) diff --git a/.github/workflows/submodule.yml b/.github/workflows/submodule.yml new file mode 100644 index 0000000000000000000000000000000000000000..4ffb261183f121a87eb39dc710cc9838dac11bbf --- /dev/null +++ b/.github/workflows/submodule.yml @@ -0,0 +1,45 @@ +name: Synchronize Submodule + +on: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" + +jobs: + sync-submodule: + runs-on: ubuntu-latest + if: github.repository == 'hpcaitech/ColossalAI' + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + ref: 'main' + submodules: true + + - name: echo + run: | + echo ${{github}} + + - name: Git Sumbodule Update + run: | + git pull --recurse-submodules + git submodule update --remote --recursive + + - name: Commit update + run: | + git config --global user.name 'github-actions' + git config --global user.email 'github-actions@github.com' + git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }} + git commit -am "Automated submodule synchronization" + + - name: Create Pull Request + uses: peter-evans/create-pull-request@v3 + with: + title: '[Bot] Synchronize Submodule References' + body: | + Automated PR to update submodule commits + committer: GitHub + author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> + assignees: ${{ github.actor }} + delete-branch: true + branch: create-pull-request/patch-sync-submodule diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..40f3f6debeee1c31640b288d7fa682a5ffce21a4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,146 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/.build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# IDE +.idea/ +.vscode/ + +# macos +*.DS_Store +#data/ + +docs/.build + +# pytorch checkpoint +*.pt + +# ignore version.py generated by setup.py +colossalai/version.py diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..1e7631bd87604131daf3baa42be52eee89b5c1b2 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,13 @@ +[submodule "benchmark"] + path = benchmark + url = https://github.com/hpcaitech/ColossalAI-Benchmark.git + branch = main +[submodule "examples"] + path = examples + url = https://github.com/hpcaitech/ColossalAI-Examples.git + branch = main + +[submodule "inference"] + path = inference + url = https://github.com/hpcaitech/EnergonAI.git + branch = main diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000000000000000000000000000000000000..090aa28e39f32da8c0161d5317c710b6c8781641 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,5 @@ +[settings] +line_length = 120 +multi_line_output=3 +include_trailing_comma = true +ignore_comments = true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dc9087af334c2ea4009c2edc0564157f35a92954 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +repos: + + - repo: https://github.com/pycqa/isort + rev: 5.10.1 + hooks: + - id: isort + name: sort all imports (python) + + - repo: https://github.com/pre-commit/mirrors-yapf + rev: v0.32.0 + hooks: + - id: yapf + name: yapf formatter + args: ['--style=.style.yapf', '--parallel', '--in-place'] + + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: v13.0.1 + hooks: + - id: clang-format + name: clang formatter + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: check-yaml + - id: check-merge-conflict + - id: check-case-conflict + - id: trailing-whitespace + - id: end-of-file-fixer + - id: mixed-line-ending + args: ['--fix=lf'] diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..98dd0cc4e9791e7b4ac9ff9460345e1b3c6c7a55 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,30 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-20.04 + tools: + python: "3.9" + # You can also specify other tool versions: + # nodejs: "16" + # rust: "1.55" + # golang: "1.17" + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/conf.py + +# If using Sphinx, optionally build your docs in additional formats such as PDF +# formats: +# - pdf + +# Optionally declare the Python requirements required to build your docs +python: + install: + - requirements: requirements/requirements.txt + - requirements: docs/requirements.txt diff --git a/.style.yapf b/.style.yapf new file mode 100644 index 0000000000000000000000000000000000000000..05be0dc6a3a598ebd59ff00cec93ce8b809c78b5 --- /dev/null +++ b/.style.yapf @@ -0,0 +1,5 @@ +[style] +based_on_style = google +spaces_before_comment = 4 +split_before_logical_operator = true +column_limit = 120 diff --git a/CHANGE_LOG.md b/CHANGE_LOG.md new file mode 100644 index 0000000000000000000000000000000000000000..bbf1d62f908b454191104af1c6c1276b6178e6d2 --- /dev/null +++ b/CHANGE_LOG.md @@ -0,0 +1,36 @@ +# Change Log + +All notable changes to this project will be documented in this file. + +## v0.0.2 | 2022-02 + +### Added + +- Unified distributed layers +- MoE support +- DevOps tools such as github action, code review automation, etc. +- New project official website + +### Changes + +- refactored the APIs for usability, flexibility and modularity +- adapted PyTorch AMP for tensor parallel +- refactored utilities for tensor parallel and pipeline parallel +- Separated benchmarks and examples as independent repositories +- Updated pipeline parallelism to support non-interleaved and interleaved versions +- refactored installation scripts for convenience + +### Fixed + +- zero level 3 runtime error +- incorrect calculation in gradient clipping + + +## v0.0.1 beta | 2021-10 + +The first beta version of Colossal-AI. Thanks to all contributors for the effort to implement the system. + +### Added + +- Initial architecture of the system +- Features such as tensor parallelism, gradient clipping, gradient accumulation diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..00abcf650158c571411f7b78a0ef4c982365b746 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,141 @@ +# Contributing + +Colossal-AI welcomes any constructive contribution from the community and the team is more than willing to work on problems you have encountered to make it a better project. + +## Environment Setup + +To contribute to Colossal-AI, we would like to first guide you to set up a proper development environment so that you can better implement your code. It is good to install this system from source with the `editable` flag (`-e`, for development mode) so that your change to the source code will be reflected in runtime without repeated installation and uninstallation. Here are the steps to set up the development environment. + +1. Uninstall any existing Colossal-AI distribution. + +```shell +pip uninstall colossalai +``` + +2. Clone the repository to local workspace + +```shell +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI +``` + +3. The *Get Started* section of [official documentation](https://colossalai.org) has provided instructions to build from source. Follow to instruction to build from source, **but replace the last `pip install` statement with the command below by adding the `-e` flag.** + +```shell +pip install -e . +``` + +## Coding Standards + +### Unit Tests +We use [PyTest](https://docs.pytest.org/en/latest/) to execute tests. You can install pytest by `pip install pytest`. As some of the tests require initialization of the distributed backend, GPUs are needed to execute these tests. + +If you only want to run CPU tests, you can run + +```bash +pytest -m cpu tests/ +``` + +If you have 8 GPUs on your machine, you can run the full test + +```bash +pytest tests/ +``` + +If you do not have 8 GPUs on your machine, do not worry. Unit testing will be automatically conducted when you put up a pull request to the main branch. + + +### Code Style + +We have some static checks when you commit your code change, please make sure you can pass all the tests and make sure the coding style meets our requirements. We use pre-commit hook to make sure the code is aligned with the writing standard. To set up the code style checking, you need to follow the steps below. + +```shell +# these commands are executed under the Colossal-AI directory +pip install pre-commit +pre-commit install +``` + +Code format checking will be automatically executed when you commit your changes. + + +## Contribution Guide + +You need to follow these steps below to make contribution to the main repository via pull request. You can learn about the details of pull request [here](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests). + +### 1. Fork the Official Repository + +Firstly, you need to visit the [Colossal-AI repository](https://github.com/hpcaitech/ColossalAI) and fork into your own account. The `fork` button is at the right top corner of the web page alongside with buttons such as `watch` and `star`. + +Now, you can clone your own forked repository into your local environment. + +```shell +git clone https://github.com//ColossalAI.git +``` + +### 2. Configure Git + +You need to set the official repository as your upstream so that you can synchronize with the latest update in the official repository. You can learn about upstream [here](https://www.atlassian.com/git/tutorials/git-forks-and-upstreams). + +Then add the original repository as upstream + +```shell +cd ColossalAI +git remote add upstream https://github.com/hpcaitech/ColossalAI.git +``` + +you can use the following command to verify that the remote is set. You should see both `origin` and `upstream` in the output. + +```shell +git remote -v +``` + +### 3. Synchronize with Official Repository + +Before you make changes to the codebase, it is always good to fetch the latest updates in the official repository. In order to do so, you can use the commands below. + +```shell +git fetch upstream +git checkout main +git merge upstream/main +git push origin main +``` + +Otherwise, you can click the `fetch upstream` button on the github webpage of the main branch of your forked repository. Then, use these commands to sync. + +``` +git checkout main +git fetch main +``` + +### 4. Choose/Create an Issue for Your Pull Request + +Generally, your code change should be only targeted at one problem. Stacking multiple commits for different problems into one pull request will only make the code review such dire suffering and make the system prone to new bugs as the reviewer may not understand the code logic correctly. Thus, you should choose an existing issue or [create your own issue](https://github.com/hpcaitech/ColossalAI/issues) as your pull request target. If you wish to create a new issue, do use appropriate title and description and add related labels. + + +### 5. Create a New Branch + +You should not make changes to the `main` branch of your forked repository as this might make upstream synchronization difficult. You can create a new branch with the appropriate name. General branch name format should start with `hotfix/` and `feature/`. `hotfix` is for bug fix and `feature` is for addition of a new feature. + + +```shell +git checkout -b +``` + +### 6. Implementation and Code Commit + +Now you can implement your code change in the source code. Remember that you installed the system in development, thus you do not need to uninstall and install to make the code take effect. The code change will be reflected in every new PyThon execution. +You can commit and push the changes to your local repository. The changes should be kept logical, modular and atomic. + +```shell +git add -A +git commit -m "" +git push -u origin +``` + +### 7. Open a Pull Request + +You can now create a pull request on the GitHub webpage of your repository. The source branch is `` of your repository and the target branch should be `main` of `hpcaitech/ColossalAI`. After creating this pull request, you should be able to see it [here](https://github.com/hpcaitech/ColossalAI/pulls). + +Do write clearly the description of your pull request and [link the pull request to your target issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue). This will automatically close the issue when the pull request is approved. + +In case of code conflict, you should rebase your branch and resolve the conflicts manually. \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0528c89ea9ecd51713b60db68ae69702d8d164f7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ +Copyright 2021- HPC-AI Technology Inc. All rights reserved. + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021- HPC-AI Technology Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..baf2892701b2959c916bae7fa9f1d167295adaf1 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,3 @@ +include *.txt README.md +recursive-include requirements *.txt +recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi diff --git a/README-zh-Hans.md b/README-zh-Hans.md new file mode 100644 index 0000000000000000000000000000000000000000..ad5b72e9fb2b4cd1d36f15f0bda33abf9fc59ff2 --- /dev/null +++ b/README-zh-Hans.md @@ -0,0 +1,324 @@ +# Colossal-AI +
+ + [![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/Colossal-AI_logo.png)](https://www.colossalai.org/) + + Colossal-AI: ไธ€ไธช้ขๅ‘ๅคงๆจกๅž‹ๆ—ถไปฃ็š„้€š็”จๆทฑๅบฆๅญฆไน ็ณป็ปŸ + +

่ฎบๆ–‡ | + ๆ–‡ๆกฃ | + ไพ‹็จ‹ | + ่ฎบๅ› | + ๅšๅฎข

+ + [![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml) + [![Documentation](https://readthedocs.org/projects/colossalai/badge/?version=latest)](https://colossalai.readthedocs.io/en/latest/?badge=latest) + [![CodeFactor](https://www.codefactor.io/repository/github/hpcaitech/colossalai/badge)](https://www.codefactor.io/repository/github/hpcaitech/colossalai) + [![HuggingFace badge](https://img.shields.io/badge/%F0%9F%A4%97HuggingFace-Join-yellow)](https://huggingface.co/hpcai-tech) + [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) + [![WeChat badge](https://img.shields.io/badge/ๅพฎไฟก-ๅŠ ๅ…ฅ-green?logo=wechat&)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png) + + | [English](README.md) | [ไธญๆ–‡](README-zh-Hans.md) | + +
+ +## ๆ–ฐ้—ป + +* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://medium.com/@yangyou_berkeley/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper-85e970fe207b) +* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://medium.com/@yangyou_berkeley/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding-10-000-4c8f0a389cd) +* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://medium.com/@yangyou_berkeley/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for-6b4c3aba07a8) +* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://medium.com/@hpcaitech/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the-892468cc2b02) +* [2022/07] [Colossal-AI Seamlessly Accelerates Large Models at Low Costs with Hugging Face](https://medium.com/@yangyou_berkeley/colossal-ai-seamlessly-accelerates-large-models-at-low-costs-with-hugging-face-4d1a887e500d) + + +## ็›ฎๅฝ• + + +## ไธบไฝ•้€‰ๆ‹ฉ Colossal-AI +
+ + + + + James Demmel ๆ•™ๆŽˆ (ๅŠ ๅทžๅคงๅญฆไผฏๅ…‹ๅˆฉๅˆ†ๆ ก): Colossal-AI ่ฎฉๅˆ†ๅธƒๅผ่ฎญ็ปƒ้ซ˜ๆ•ˆใ€ๆ˜“็”จใ€ๅฏๆ‰ฉๅฑ•ใ€‚ +
+ +

(่ฟ”ๅ›ž้กถ็ซฏ)

+ +## ็‰น็‚น + +Colossal-AI ไธบๆ‚จๆไพ›ไบ†ไธ€็ณปๅˆ—ๅนถ่กŒ็ป„ไปถใ€‚ๆˆ‘ไปฌ็š„็›ฎๆ ‡ๆ˜ฏ่ฎฉๆ‚จ็š„ๅˆ†ๅธƒๅผ AI ๆจกๅž‹ๅƒๆž„ๅปบๆ™ฎ้€š็š„ๅ• GPU ๆจกๅž‹ไธ€ๆ ท็ฎ€ๅ•ใ€‚ๆˆ‘ไปฌๆไพ›็š„ๅ‹ๅฅฝๅทฅๅ…ทๅฏไปฅ่ฎฉๆ‚จๅœจๅ‡ ่กŒไปฃ็ ๅ†…ๅฟซ้€Ÿๅผ€ๅง‹ๅˆ†ๅธƒๅผ่ฎญ็ปƒๅ’ŒๆŽจ็†ใ€‚ + +- ๅนถ่กŒๅŒ–็ญ–็•ฅ + - ๆ•ฐๆฎๅนถ่กŒ + - ๆตๆฐด็บฟๅนถ่กŒ + - 1็ปด, [2็ปด](https://arxiv.org/abs/2104.05343), [2.5็ปด](https://arxiv.org/abs/2105.14500), [3็ปด](https://arxiv.org/abs/2105.14450) ๅผ ้‡ๅนถ่กŒ + - [ๅบๅˆ—ๅนถ่กŒ](https://arxiv.org/abs/2105.13120) + - [้›ถๅ†—ไฝ™ไผ˜ๅŒ–ๅ™จ (ZeRO)](https://arxiv.org/abs/1910.02054) +- ๅผ‚ๆž„ๅ†…ๅญ˜็ฎก็† + - [PatrickStar](https://arxiv.org/abs/2108.05818) +- ไฝฟ็”จๅ‹ๅฅฝ + - ๅŸบไบŽๅ‚ๆ•ฐๆ–‡ไปถ็š„ๅนถ่กŒๅŒ– +- ๆŽจ็† + - [Energon-AI](https://github.com/hpcaitech/EnergonAI) +- Colossal-AI ๆˆๅŠŸๆกˆไพ‹ + - ็”Ÿ็‰ฉๅŒป่ฏ: [FastFold](https://github.com/hpcaitech/FastFold) ๅŠ ้€Ÿ่›‹็™ฝ่ดจ็ป“ๆž„้ข„ๆต‹ AlphaFold ่ฎญ็ปƒไธŽๆŽจ็† +

(่ฟ”ๅ›ž้กถ็ซฏ)

+ +## ๅนถ่กŒ่ฎญ็ปƒๆ ทไพ‹ๅฑ•็คบ +### ViT +

+ +

+ +- 14ๅ€ๆ‰นๅคงๅฐๅ’Œ5ๅ€่ฎญ็ปƒ้€Ÿๅบฆ๏ผˆๅผ ้‡ๅนถ่กŒ=64๏ผ‰ + +### GPT-3 +

+ +

+ +- ้‡Šๆ”พ 50% GPU ่ต„ๆบๅ ็”จ, ๆˆ– 10.7% ๅŠ ้€Ÿ + +### GPT-2 + + +- ้™ไฝŽ11ๅ€ GPU ๆ˜พๅญ˜ๅ ็”จ๏ผŒๆˆ–่ถ…็บฟๆ€งๆ‰ฉๅฑ•๏ผˆๅผ ้‡ๅนถ่กŒ๏ผ‰ + + + +- ็”จ็›ธๅŒ็š„็กฌไปถ่ฎญ็ปƒ24ๅ€ๅคง็š„ๆจกๅž‹ +- ่ถ…3ๅ€็š„ๅžๅ้‡ + +### BERT + + +- 2ๅ€่ฎญ็ปƒ้€Ÿๅบฆ๏ผŒๆˆ–1.5ๅ€ๅบๅˆ—้•ฟๅบฆ + +### PaLM +- [PaLM-colossalai](https://github.com/hpcaitech/PaLM-colossalai): ๅฏๆ‰ฉๅฑ•็š„่ฐทๆญŒ Pathways Language Model ([PaLM](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html)) ๅฎž็Žฐใ€‚ + +### OPT + + +- [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), ็”ฑMetaๅ‘ๅธƒ็š„1750ไบฟ่ฏญ่จ€ๆจกๅž‹๏ผŒ็”ฑไบŽๅฎŒๅ…จๅ…ฌๅผ€ไบ†้ข„่ฎญ็ปƒๅ‚ๆ•ฐๆƒ้‡๏ผŒๅ› ๆญคไฟƒ่ฟ›ไบ†ไธ‹ๆธธไปปๅŠกๅ’Œๅบ”็”จ้ƒจ็ฝฒ็š„ๅ‘ๅฑ•ใ€‚ +- ๅŠ ้€Ÿ45%๏ผŒไป…็”จๅ‡ ่กŒไปฃ็ ไปฅไฝŽๆˆๆœฌๅพฎ่ฐƒOPTใ€‚[[ๆ ทไพ‹]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[ๅœจ็บฟๆŽจ็†]](https://service.colossalai.org/opt) + +่ฏท่ฎฟ้—ฎๆˆ‘ไปฌ็š„ [ๆ–‡ๆกฃ](https://www.colossalai.org/) ๅ’Œ [ไพ‹็จ‹](https://github.com/hpcaitech/ColossalAI-Examples) ไปฅไบ†่งฃ่ฏฆๆƒ…ใ€‚ + + +### ๆŽจ่็ณป็ปŸๆจกๅž‹ +- [Cached Embedding](https://github.com/hpcaitech/CachedEmbedding), ไฝฟ็”จ่ฝฏไปถCacheๅฎž็ŽฐEmbeddings๏ผŒ็”จๆ›ดๅฐ‘GPUๆ˜พๅญ˜่ฎญ็ปƒๆ›ดๅคง็š„ๆจกๅž‹ใ€‚ + + +

(่ฟ”ๅ›ž้กถ็ซฏ)

+ +## ๅ•GPU่ฎญ็ปƒๆ ทไพ‹ๅฑ•็คบ + +### GPT-2 +

+ +

+ +- ็”จ็›ธๅŒ็š„็กฌไปถ่ฎญ็ปƒ20ๅ€ๅคง็š„ๆจกๅž‹ + +

+ +

+ +- ็”จ็›ธๅŒ็š„็กฌไปถ่ฎญ็ปƒ120ๅ€ๅคง็š„ๆจกๅž‹ (RTX 3080) + +### PaLM +

+ +

+ +- ็”จ็›ธๅŒ็š„็กฌไปถ่ฎญ็ปƒ34ๅ€ๅคง็š„ๆจกๅž‹ + +

(่ฟ”ๅ›ž้กถ็ซฏ)

+ + +## ๆŽจ็† (Energon-AI) ๆ ทไพ‹ๅฑ•็คบ + +

+ +

+ +- [Energon-AI](https://github.com/hpcaitech/EnergonAI) ๏ผš็”จ็›ธๅŒ็š„็กฌไปถๆŽจ็†ๅŠ ้€Ÿ50% + +

+ +

+ +- [OPTๆŽจ็†ๆœๅŠก](https://service.colossalai.org/opt): ๆ— ้œ€ๆณจๅ†Œ๏ผŒๅ…่ดนไฝ“้ชŒ1750ไบฟๅ‚ๆ•ฐOPTๅœจ็บฟๆŽจ็†ๆœๅŠก + + +

(่ฟ”ๅ›ž้กถ็ซฏ)

+ +## Colossal-AI ๆˆๅŠŸๆกˆไพ‹ + +### AIGC +ๅŠ ้€ŸAIGC(AIๅ†…ๅฎน็”Ÿๆˆ)ๆจกๅž‹๏ผŒๅฆ‚[Stable Diffusion](https://github.com/CompVis/stable-diffusion) +

+ +

+ +- [Colossal-AIไผ˜ๅŒ–Stable Diffusion](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): 6.5ๅ€่ฎญ็ปƒๅŠ ้€Ÿๅ’Œ้ข„่ฎญ็ปƒๆˆๆœฌ้™ไฝŽ, ๅพฎ่ฐƒ็กฌไปถๆˆๆœฌไธ‹้™็บฆ7ๅ€(ไปŽRTX3090/4090ๅˆฐRTX3050/2070) + +

+ +

+ +

(่ฟ”ๅ›ž้กถ็ซฏ)

+ +### ็”Ÿ็‰ฉๅŒป่ฏ + +ๅŠ ้€Ÿ [AlphaFold](https://alphafold.ebi.ac.uk/) ่›‹็™ฝ่ดจ็ป“ๆž„้ข„ๆต‹ + +

+ +

+ +- [FastFold](https://github.com/hpcaitech/FastFold): ๅŠ ้€ŸAlphaFold่ฎญ็ปƒไธŽๆŽจ็†ใ€ๆ•ฐๆฎๅ‰ๅค„็†ใ€ๆŽจ็†ๅบๅˆ—้•ฟๅบฆ่ถ…่ฟ‡10000ๆฎ‹ๅŸบ + +

+ +

+ +- [xTrimoMultimer](https://github.com/biomap-research/xTrimoMultimer): 11ๅ€ๅŠ ้€Ÿ่›‹็™ฝ่ดจๅ•ไฝ“ไธŽๅคๅˆ็‰ฉ็ป“ๆž„้ข„ๆต‹ + +

(่ฟ”ๅ›ž้กถ็ซฏ)

+ +## ๅฎ‰่ฃ… + +### ไปŽๅฎ˜ๆ–นๅฎ‰่ฃ… + +ๆ‚จๅฏไปฅ่ฎฟ้—ฎๆˆ‘ไปฌ[ไธ‹่ฝฝ](https://www.colossalai.org/download)้กต้ขๆฅๅฎ‰่ฃ…Colossal-AI๏ผŒๅœจ่ฟ™ไธช้กต้ขไธŠๅ‘ๅธƒ็š„็‰ˆๆœฌ้ƒฝ้ข„็ผ–่ฏ‘ไบ†CUDAๆ‰ฉๅฑ•ใ€‚ + +### ไปŽๆบๅฎ‰่ฃ… + +> ๆญคๆ–‡ๆกฃๅฐ†ไธŽ็‰ˆๆœฌๅบ“็š„ไธปๅˆ†ๆ”ฏไฟๆŒไธ€่‡ดใ€‚ๅฆ‚ๆžœๆ‚จ้‡ๅˆฐไปปไฝ•้—ฎ้ข˜๏ผŒๆฌข่ฟŽ็ป™ๆˆ‘ไปฌๆ issue :) + +```shell +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI + +# install dependency +pip install -r requirements/requirements.txt + +# install colossalai +pip install . +``` + +ๅฆ‚ๆžœๆ‚จไธๆƒณๅฎ‰่ฃ…ๅ’Œๅฏ็”จ CUDA ๅ†…ๆ ธ่žๅˆ๏ผˆไฝฟ็”จ่žๅˆไผ˜ๅŒ–ๅ™จๆ—ถๅผบๅˆถๅฎ‰่ฃ…๏ผ‰๏ผš + +```shell +NO_CUDA_EXT=1 pip install . +``` + +

(่ฟ”ๅ›ž้กถ็ซฏ)

+ +## ไฝฟ็”จ Docker + +### ไปŽDockerHub่Žทๅ–้•œๅƒ + +ๆ‚จๅฏไปฅ็›ดๆŽฅไปŽๆˆ‘ไปฌ็š„[DockerHubไธป้กต](https://hub.docker.com/r/hpcaitech/colossalai)่Žทๅ–ๆœ€ๆ–ฐ็š„้•œๅƒ๏ผŒๆฏไธ€ๆฌกๅ‘ๅธƒๆˆ‘ไปฌ้ƒฝไผš่‡ชๅŠจไธŠไผ ๆœ€ๆ–ฐ็š„้•œๅƒใ€‚ + +### ๆœฌๅœฐๆž„ๅปบ้•œๅƒ + +่ฟ่กŒไปฅไธ‹ๅ‘ฝไปคไปŽๆˆ‘ไปฌๆไพ›็š„ docker ๆ–‡ไปถไธญๅปบ็ซ‹ docker ้•œๅƒใ€‚ + +> ๅœจDockerfile้‡Œ็ผ–่ฏ‘Colossal-AI้œ€่ฆๆœ‰GPUๆ”ฏๆŒ๏ผŒๆ‚จ้œ€่ฆๅฐ†Nvidia Docker Runtime่ฎพ็ฝฎไธบ้ป˜่ฎค็š„Runtimeใ€‚ๆ›ดๅคšไฟกๆฏๅฏไปฅ็‚นๅ‡ป[่ฟ™้‡Œ](https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime)ใ€‚ +> ๆˆ‘ไปฌๆŽจ่ไปŽ[้กน็›ฎไธป้กต](https://www.colossalai.org)็›ดๆŽฅไธ‹่ฝฝColossal-AI. + +```bash +cd ColossalAI +docker build -t colossalai ./docker +``` + +่ฟ่กŒไปฅไธ‹ๅ‘ฝไปคไปŽไปฅไบคไบ’ๅผๅฏๅŠจ docker ้•œๅƒ. + +```bash +docker run -ti --gpus all --rm --ipc=host colossalai bash +``` + +

(่ฟ”ๅ›ž้กถ็ซฏ)

+ +## ็คพๅŒบ +ๆฌข่ฟŽ้€š่ฟ‡[่ฎบๅ›](https://github.com/hpcaitech/ColossalAI/discussions), +[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +ๆˆ–[ๅพฎไฟก](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode")ๅŠ ๅ…ฅ Colossal-AI ็คพๅŒบ๏ผŒไธŽๆˆ‘ไปฌๅˆ†ไบซไฝ ็š„ๅปบ่ฎฎๅ’Œ้—ฎ้ข˜ใ€‚ + + +## ๅšๅ‡บ่ดก็Œฎ + +ๆฌข่ฟŽไธบ่ฏฅ้กน็›ฎๅšๅ‡บ่ดก็Œฎ๏ผŒ่ฏทๅ‚้˜…[่ดก็ŒฎๆŒ‡ๅ—](./CONTRIBUTING.md)ใ€‚ + +็œŸ่ฏšๆ„Ÿ่ฐขๆ‰€ๆœ‰่ดก็Œฎ่€…๏ผ + + + +*่ดก็Œฎ่€…ๅคดๅƒ็š„ๅฑ•็คบ้กบๅบๆ˜ฏ้šๆœบ็š„ใ€‚* + +

(่ฟ”ๅ›ž้กถ็ซฏ)

+ + +## ๅผ•็”จๆˆ‘ไปฌ + +``` +@article{bian2021colossal, + title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, + author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, + journal={arXiv preprint arXiv:2110.14883}, + year={2021} +} +``` + +

(่ฟ”ๅ›ž้กถ็ซฏ)

\ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f27680d8c5ab058b878683c0f796a7fa2ebe36aa --- /dev/null +++ b/README.md @@ -0,0 +1,329 @@ +# Colossal-AI +
+ + [![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/Colossal-AI_logo.png)](https://www.colossalai.org/) + + Colossal-AI: A Unified Deep Learning System for Big Model Era + +

Paper | + Documentation | + Examples | + Forum | + Blog

+ + [![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml) + [![Documentation](https://readthedocs.org/projects/colossalai/badge/?version=latest)](https://colossalai.readthedocs.io/en/latest/?badge=latest) + [![CodeFactor](https://www.codefactor.io/repository/github/hpcaitech/colossalai/badge)](https://www.codefactor.io/repository/github/hpcaitech/colossalai) + [![HuggingFace badge](https://img.shields.io/badge/%F0%9F%A4%97HuggingFace-Join-yellow)](https://huggingface.co/hpcai-tech) + [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) + [![WeChat badge](https://img.shields.io/badge/ๅพฎไฟก-ๅŠ ๅ…ฅ-green?logo=wechat&)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png) + + + | [English](README.md) | [ไธญๆ–‡](README-zh-Hans.md) | + +
+ +## Latest News + +* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://medium.com/@yangyou_berkeley/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper-85e970fe207b) +* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://medium.com/@yangyou_berkeley/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding-10-000-4c8f0a389cd) +* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://medium.com/@yangyou_berkeley/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for-6b4c3aba07a8) +* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://medium.com/@hpcaitech/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the-892468cc2b02) +* [2022/07] [Colossal-AI Seamlessly Accelerates Large Models at Low Costs with Hugging Face](https://medium.com/@yangyou_berkeley/colossal-ai-seamlessly-accelerates-large-models-at-low-costs-with-hugging-face-4d1a887e500d) + +## Table of Contents + + +## Why Colossal-AI +
+ + + + + Prof. James Demmel (UC Berkeley): Colossal-AI makes training AI models efficient, easy, and scalable. +
+ +

(back to top)

+ +## Features + +Colossal-AI provides a collection of parallel components for you. We aim to support you to write your +distributed deep learning models just like how you write your model on your laptop. We provide user-friendly tools to kickstart +distributed training and inference in a few lines. + +- Parallelism strategies + - Data Parallelism + - Pipeline Parallelism + - 1D, [2D](https://arxiv.org/abs/2104.05343), [2.5D](https://arxiv.org/abs/2105.14500), [3D](https://arxiv.org/abs/2105.14450) Tensor Parallelism + - [Sequence Parallelism](https://arxiv.org/abs/2105.13120) + - [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054) + +- Heterogeneous Memory Management + - [PatrickStar](https://arxiv.org/abs/2108.05818) + +- Friendly Usage + - Parallelism based on configuration file + +- Inference + - [Energon-AI](https://github.com/hpcaitech/EnergonAI) + +- Colossal-AI in the Real World + - Biomedicine: [FastFold](https://github.com/hpcaitech/FastFold) accelerates training and inference of AlphaFold protein structure +

(back to top)

+ +## Parallel Training Demo +### ViT +

+ +

+ +- 14x larger batch size, and 5x faster training for Tensor Parallelism = 64 + +### GPT-3 +

+ +

+ +- Save 50% GPU resources, and 10.7% acceleration + +### GPT-2 + + +- 11x lower GPU memory consumption, and superlinear scaling efficiency with Tensor Parallelism + + + +- 24x larger model size on the same hardware +- over 3x acceleration +### BERT + + +- 2x faster training, or 50% longer sequence length + +### PaLM +- [PaLM-colossalai](https://github.com/hpcaitech/PaLM-colossalai): Scalable implementation of Google's Pathways Language Model ([PaLM](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html)). + +### OPT + + +- [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model released by Meta, which stimulates AI programmers to perform various downstream tasks and application deployments because public pretrained model weights. +- 45% speedup fine-tuning OPT at low cost in lines. [[Example]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[Online Serving]](https://service.colossalai.org/opt) + +Please visit our [documentation](https://www.colossalai.org/) and [examples](https://github.com/hpcaitech/ColossalAI-Examples) for more details. + +### Recommendation System Models +- [Cached Embedding](https://github.com/hpcaitech/CachedEmbedding), utilize software cache to train larger embedding tables with a smaller GPU memory budget. + +

(back to top)

+ +## Single GPU Training Demo + +### GPT-2 +

+ +

+ +- 20x larger model size on the same hardware + +

+ +

+ +- 120x larger model size on the same hardware (RTX 3080) + +### PaLM +

+ +

+ +- 34x larger model size on the same hardware + +

(back to top)

+ + +## Inference (Energon-AI) Demo + +

+ +

+ +- [Energon-AI](https://github.com/hpcaitech/EnergonAI): 50% inference acceleration on the same hardware + +

+ +

+ +- [OPT Serving](https://service.colossalai.org/opt): Try 175-billion-parameter OPT online services for free, without any registration whatsoever. + +

(back to top)

+ +## Colossal-AI in the Real World + +### AIGC +Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion](https://github.com/CompVis/stable-diffusion) +

+ +

+ +- [Stable Diffusion with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): 6.5x faster training and pretraining cost saving, the hardware cost of fine-tuning can be almost 7X cheaper (from RTX3090/4090 to RTX3050/2070) + +

+ +

+ +

(back to top)

+ +### Biomedicine +Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/) + +

+ +

+ +- [FastFold](https://github.com/hpcaitech/FastFold): accelerating training and inference on GPU Clusters, faster data processing, inference sequence containing more than 10000 residues. + +

+ +

+ +- [xTrimoMultimer](https://github.com/biomap-research/xTrimoMultimer): accelerating structure prediction of protein monomers and multimer by 11x. + + +

(back to top)

+ +## Installation + +### Download From Official Releases + +You can visit the [Download](https://www.colossalai.org/download) page to download Colossal-AI with pre-built CUDA extensions. + + +### Download From Source + +> The version of Colossal-AI will be in line with the main branch of the repository. Feel free to raise an issue if you encounter any problem. :) + +```shell +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI + +# install dependency +pip install -r requirements/requirements.txt + +# install colossalai +pip install . +``` + +If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer): + +```shell +NO_CUDA_EXT=1 pip install . +``` + +

(back to top)

+ +## Use Docker + +### Pull from DockerHub + +You can directly pull the docker image from our [DockerHub page](https://hub.docker.com/r/hpcaitech/colossalai). The image is automatically uploaded upon release. + + +### Build On Your Own + +Run the following command to build a docker image from Dockerfile provided. + +> Building Colossal-AI from scratch requires GPU support, you need to use Nvidia Docker Runtime as the default when doing `docker build`. More details can be found [here](https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime). +> We recommend you install Colossal-AI from our [project page](https://www.colossalai.org) directly. + + +```bash +cd ColossalAI +docker build -t colossalai ./docker +``` + +Run the following command to start the docker container in interactive mode. + +```bash +docker run -ti --gpus all --rm --ipc=host colossalai bash +``` + +

(back to top)

+ +## Community + +Join the Colossal-AI community on [Forum](https://github.com/hpcaitech/ColossalAI/discussions), +[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +and [WeChat](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your suggestions, feedback, and questions with our engineering team. + +## Contributing + +If you wish to contribute to this project, please follow the guideline in [Contributing](./CONTRIBUTING.md). + +Thanks so much to all of our amazing contributors! + + + +*The order of contributor avatars is randomly shuffled.* + +

(back to top)

+ + +## Cite Us + +``` +@article{bian2021colossal, + title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, + author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, + journal={arXiv preprint arXiv:2110.14883}, + year={2021} +} +``` + +

(back to top)

diff --git a/colossalai/_C/__init__.pyi b/colossalai/_C/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..bfd86d0ee01d41871e535b493dcf5aecb0c83279 --- /dev/null +++ b/colossalai/_C/__init__.pyi @@ -0,0 +1,9 @@ +from . import ( + cpu_optim, + fused_optim, + layer_norm, + moe, + multihead_attention, + scaled_masked_softmax, + scaled_upper_triang_masked_softmax, +) diff --git a/colossalai/_C/cpu_optim.pyi b/colossalai/_C/cpu_optim.pyi new file mode 100644 index 0000000000000000000000000000000000000000..0f7611790291291ed1f5abcc54197fcce1d54bea --- /dev/null +++ b/colossalai/_C/cpu_optim.pyi @@ -0,0 +1,8 @@ +from torch import Tensor + +class CPUAdamOptimizer: + def __init__(self, lr: float, beta1: float, beta2: float, eps: float, + weight_decay: float, adamw_mode: float) -> None: ... + + def step(self, step: int, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, bias_correction: bool, + param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor, loss_scale: float) -> None: ... diff --git a/colossalai/_C/fused_optim.pyi b/colossalai/_C/fused_optim.pyi new file mode 100644 index 0000000000000000000000000000000000000000..983b02335b41dbda8e45f128bc55ff5e2cad4822 --- /dev/null +++ b/colossalai/_C/fused_optim.pyi @@ -0,0 +1,23 @@ +from typing import List + +from torch import Tensor + +def multi_tensor_scale(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], scale: float) -> None: + ... + + +def multi_tensor_sgd(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], weight_decay: float, + momentum: float, dampening: float, lr: float, nesterov: bool, first_run: bool, weight_decay_after_momentum: bool, scale: float) -> None: + ... + + +def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float, div_scale: float) -> None: + ... + + +def multi_tensor_lamb(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, bias_correction: int, weight_decay: float, grad_averaging: int, mode: int, global_grad_norm: Tensor, max_grad_norm: float, use_nvlamb_python: bool) -> None: + ... + + +def multi_tensor_l2norm(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], per_tensor_python: bool) -> None: + ... diff --git a/colossalai/_C/layer_norm.pyi b/colossalai/_C/layer_norm.pyi new file mode 100644 index 0000000000000000000000000000000000000000..02d4587ff05e511f87c67d58e4d42a7a7368806c --- /dev/null +++ b/colossalai/_C/layer_norm.pyi @@ -0,0 +1,11 @@ +from typing import List + +from torch import Tensor + +def forward_affine(input: Tensor, normalized_shape: List[int], gamma: Tensor, beta: Tensor, epsilon: float) -> List[Tensor]: + ... + + +def backward_affine(dout: Tensor, mean: Tensor, invvar: Tensor, input: Tensor, + normalized_shape: List[int], gamma: Tensor, beta: Tensor, epsilon: float) -> List[Tensor]: + ... diff --git a/colossalai/_C/moe.pyi b/colossalai/_C/moe.pyi new file mode 100644 index 0000000000000000000000000000000000000000..121aa7e41082b32b033c7cc7e2b13ea746f20f50 --- /dev/null +++ b/colossalai/_C/moe.pyi @@ -0,0 +1,20 @@ +from torch import Tensor + +def cumsum_sub_one(mask: Tensor) -> Tensor: + ... + + +def dispatch_forward(s: int, ec: int, h: int, batch_tokens: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor: + ... + + +def dispatch_backward(s: int, ec: int, h: int, expert_grad: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor: + ... + + +def combine_forward(s: int, e: int, c: int, h: int, expert_tokens: Tensor, logits: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor: + ... + + +def combine_backward(s: int, e: int, c: int, h: int, tokens_grad: Tensor, expert_tokens: Tensor, logits: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor: + ... diff --git a/colossalai/_C/multihead_attention.pyi b/colossalai/_C/multihead_attention.pyi new file mode 100644 index 0000000000000000000000000000000000000000..7ad87ea9a624f4d85b66e085f4a96c72c366ad47 --- /dev/null +++ b/colossalai/_C/multihead_attention.pyi @@ -0,0 +1,55 @@ +from typing import List + +from torch import Tensor +from torch.distributed import ProcessGroup + +def multihead_attention_fw_fp32(layer_id: int, input: Tensor, input_mask: Tensor, + in_proj_weight: Tensor, in_proj_bias: Tensor, + out_proj_weight: Tensor, out_proj_bias: Tensor, + norm_weight: Tensor, norm_bias: Tensor, + training_mode: bool, prelayernorm: bool) -> List[Tensor]: + ... + + +def multihead_attention_fw_fp16(layer_id: int, input: Tensor, input_mask: Tensor, + in_proj_weight: Tensor, in_proj_bias: Tensor, + out_proj_weight: Tensor, out_proj_bias: Tensor, + norm_weight: Tensor, norm_bias: Tensor, + training_mode: bool, prelayernorm: bool) -> List[Tensor]: + ... + + +def multihead_attention_bw_fp32(layer_id: int, grad_dec_output: Tensor, + output: Tensor, input: Tensor, + input_mask: Tensor, in_proj_weight: Tensor, + in_proj_bias: Tensor, out_proj_weight: Tensor, + out_proj_bias: Tensor, norm_weight: Tensor, + norm_bias: Tensor) -> List[Tensor]: + ... + + +def multihead_attention_bw_fp16(layer_id: int, grad_dec_output: Tensor, + output: Tensor, input: Tensor, + input_mask: Tensor, in_proj_weight: Tensor, + in_proj_bias: Tensor, out_proj_weight: Tensor, + out_proj_bias: Tensor, norm_weight: Tensor, + norm_bias: Tensor) -> List[Tensor]: + ... + + +def create_multihead_attention_fp32(layer_id: int, max_batch_tokens: int, + max_seq_len: int, hidden_dim: int, num_heads: int, + attn_prob_dropout_ratio: float, + hidden_dropout_ratio: float, + pre_or_postLayerNorm: bool, + pg: ProcessGroup) -> int: + ... + + +def create_multihead_attention_fp16(layer_id: int, max_batch_tokens: int, + max_seq_len: int, hidden_dim: int, num_heads: int, + attn_prob_dropout_ratio: float, + hidden_dropout_ratio: float, + pre_or_postLayerNorm: bool, + pg: ProcessGroup) -> int: + ... diff --git a/colossalai/_C/scaled_masked_softmax.pyi b/colossalai/_C/scaled_masked_softmax.pyi new file mode 100644 index 0000000000000000000000000000000000000000..fdb88266ef0bfb754f568c771f1f3fd2f8f32330 --- /dev/null +++ b/colossalai/_C/scaled_masked_softmax.pyi @@ -0,0 +1,12 @@ +from torch import Tensor + +def forward(input: Tensor, mask: Tensor, scale: float) -> Tensor: + ... + + +def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor: + ... + + +def get_batch_per_block(query_seq_len: int, key_seq_len: int, batches: int, attn_heads: int) -> int: + ... diff --git a/colossalai/_C/scaled_upper_triang_masked_softmax.pyi b/colossalai/_C/scaled_upper_triang_masked_softmax.pyi new file mode 100644 index 0000000000000000000000000000000000000000..39a3d6b2299bcb4409ad710f448e15f7ecc50792 --- /dev/null +++ b/colossalai/_C/scaled_upper_triang_masked_softmax.pyi @@ -0,0 +1,8 @@ +from torch import Tensor + +def forward(input: Tensor, scale: float) -> Tensor: + ... + + +def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor: + ... diff --git a/colossalai/__init__.py b/colossalai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f859161f78108e4c8b5cfba89e12026fee28da43 --- /dev/null +++ b/colossalai/__init__.py @@ -0,0 +1,17 @@ +from .initialize import ( + get_default_parser, + initialize, + launch, + launch_from_openmpi, + launch_from_slurm, + launch_from_torch, +) + +try: + # .version will be created by setup.py + from .version import __version__ +except ModuleNotFoundError: + # this will only happen if the user did not run `pip install` + # and directly set PYTHONPATH to use Colossal-AI which is a bad practice + __version__ = '0.0.0' + print('please install Colossal-AI from https://www.colossalai.org/download or from source') diff --git a/colossalai/amp/__init__.py b/colossalai/amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16da81f23898af853f19134b19f2dcb8eb18f615 --- /dev/null +++ b/colossalai/amp/__init__.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from .amp_type import AMP_TYPE +from colossalai.context import Config +import torch.nn as nn +from torch.optim import Optimizer +from torch.nn.modules.loss import _Loss +from .torch_amp import convert_to_torch_amp +from .apex_amp import convert_to_apex_amp +from .naive_amp import convert_to_naive_amp + +__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE'] + + +def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None): + """A helper function to wrap training components with Torch AMP modules. + + Args: + param model (:class:`torch.nn.Module`): your model object. + optimizer (:class:`torch.optim.Optimizer`): your optimizer object. + criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object. + mode (:class:`colossalai.amp.AMP_TYPE`): amp mode. + amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes. + + Returns: + A tuple (model, optimizer, criterion). + + Note: + ``amp_config`` may vary from different mode you choose. You should check the corresponding amp mode + for more details about ``amp_config``. + For ``apex_amp``, please check + `apex_amp config `_. + For ``naive_amp``, please check + `naive_amp config `_. + For ``torch_amp``, please check + `torch_amp config `_. + """ + assert isinstance(mode, AMP_TYPE), \ + f'expected the argument mode be AMP_TYPE, but got {type(mode)}' + + if amp_config is None: + amp_config = Config() + + if mode == AMP_TYPE.TORCH: + model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config) + elif mode == AMP_TYPE.APEX: + model, optimizer = convert_to_apex_amp(model, optimizer, amp_config) + elif mode == AMP_TYPE.NAIVE: + model, optimizer = convert_to_naive_amp(model, optimizer, amp_config) + + return model, optimizer, criterion diff --git a/colossalai/amp/amp_type.py b/colossalai/amp/amp_type.py new file mode 100644 index 0000000000000000000000000000000000000000..6f322f866cfc813e66e54b0c1006d62ef949e96e --- /dev/null +++ b/colossalai/amp/amp_type.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from enum import Enum + + +class AMP_TYPE(Enum): + APEX = 'apex' + TORCH = 'torch' + NAIVE = 'naive' diff --git a/colossalai/amp/apex_amp/__init__.py b/colossalai/amp/apex_amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..51b9b97dccce877783251fb3f61f08a87a6a7659 --- /dev/null +++ b/colossalai/amp/apex_amp/__init__.py @@ -0,0 +1,42 @@ +import torch.nn as nn +from torch.optim import Optimizer + +from .apex_amp import ApexAMPOptimizer + + +def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config): + r"""A helper function to wrap training components with Apex AMP modules + + Args: + model (:class:`torch.nn.Module`): your model object. + optimizer (:class:`torch.optim.Optimizer`): your optimizer object. + amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for initializing apex_amp. + + Returns: + Tuple: A tuple (model, optimizer). + + The ``amp_config`` should include parameters below: + :: + + enabled (bool, optional, default=True) + opt_level (str, optional, default="O1") + cast_model_type (``torch.dtype``, optional, default=None) + patch_torch_functions (bool, optional, default=None) + keep_batchnorm_fp32 (bool or str, optional, default=None + master_weights (bool, optional, default=None) + loss_scale (float or str, optional, default=None) + cast_model_outputs (torch.dtype, optional, default=None) + num_losses (int, optional, default=1) + verbosity (int, default=1) + min_loss_scale (float, default=None) + max_loss_scale (float, default=2.**24) + + More details about ``amp_config`` refer to `amp_config `_. + """ + import apex.amp as apex_amp + model, optimizer = apex_amp.initialize(model, optimizer, **amp_config) + optimizer = ApexAMPOptimizer(optimizer) + return model, optimizer + + +__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer'] diff --git a/colossalai/amp/apex_amp/apex_amp.py b/colossalai/amp/apex_amp/apex_amp.py new file mode 100644 index 0000000000000000000000000000000000000000..69a4e348e5a7250ad6067199c660bf70cd2a2621 --- /dev/null +++ b/colossalai/amp/apex_amp/apex_amp.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.nn as nn +try: + import apex.amp as apex_amp +except ImportError: + pass + +from torch import Tensor + +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.utils import clip_grad_norm_fp32 + + +class ApexAMPOptimizer(ColossalaiOptimizer): + """ A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm + methods + """ + + def backward(self, loss: Tensor): + """Backward pass to get all gradients + + Args: + loss (torch.Tensor): Loss computed by a loss function + """ + with apex_amp.scale_loss(loss, self.optim) as scaled_loss: + scaled_loss.backward() + + def clip_grad_norm(self, model: nn.Module, max_norm: float): + """Clip gradients by norm + + Args: + model (torch.nn.Module): Your model object + max_norm (float): The max norm value for gradient clipping + """ + if max_norm > 0: + clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm) diff --git a/colossalai/amp/naive_amp/__init__.py b/colossalai/amp/naive_amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b2f71d3ced771c43d541843153c6b64613f69e1 --- /dev/null +++ b/colossalai/amp/naive_amp/__init__.py @@ -0,0 +1,60 @@ +import inspect + +import torch.nn as nn +from torch.optim import Optimizer + +from colossalai.utils import is_no_pp_or_last_stage + +from ._fp16_optimizer import FP16Optimizer +from .grad_scaler import ConstantGradScaler, DynamicGradScaler +from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer + + +def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config): + """A helper function to wrap training components with naive AMP modules. In this mode, + we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss, + which is equivalent to Apex O3. + + Args: + model (:class:`torch.nn.Module`): your model object + optimizer (:class:`torch.optim.Optimizer`): your optimizer object + amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp. + + Returns: + Tuple: A tuple (model, optimizer) + + The ``amp_config`` should contain parameters below:: + + verbose (bool, optional): if set to `True`, will print debug info (Default: False). + clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0). + Note that clipping is ignored if clip_grad == 0. + dynamic_grad_scale (bool): whether to use dynamic grad scaler. + """ + if isinstance(model, nn.ModuleList): + # interleaved pipeline + module_list = [] + for chunk, m in enumerate(model): + output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1 + module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32)) + model = nn.ModuleList(module_list) + else: + output_to_fp32 = is_no_pp_or_last_stage() + model = NaiveAMPModel(model, output_to_fp32=output_to_fp32) + + use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True) + if use_dynamic_grad_scaler: + scaler_class = DynamicGradScaler + else: + scaler_class = ConstantGradScaler + + sig = inspect.signature(scaler_class.__init__) + kwargs = dict() + for param in sig.parameters.values(): + if param.name in amp_config: + kwargs[param.name] = amp_config.pop(param.name) + grad_scaler = scaler_class(**kwargs) + optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config) + return model, optimizer + + +__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer'] diff --git a/colossalai/amp/naive_amp/_fp16_optimizer.py b/colossalai/amp/naive_amp/_fp16_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..9a8be009bf4a2ea5bc5b32c613d619160dc70edd --- /dev/null +++ b/colossalai/amp/naive_amp/_fp16_optimizer.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +import torch.distributed as dist + +try: + import colossalai._C.fused_optim +except: + print('Colossalai should be built with cuda extension to use the FP16 optimizer') + +from torch.distributed import ProcessGroup +from torch.optim import Optimizer + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier + +from ._utils import has_inf_or_nan, zero_gard_by_list +from .grad_scaler import BaseGradScaler + +__all__ = ['FP16Optimizer'] + + +def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): + """ + adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM) + + Use multi-tensor-applier to copy values from one list to another. + We don't have a blfoat16 implementation so for now if the overflow_buf + is not provided, we default back to simple loop copy to be compatible + with bfloat16. + """ + if overflow_buf: + overflow_buf.fill_(0) + # Scaling with factor `1.0` is equivalent to copy. + multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0) + else: + for this_, that_ in zip(this, that): + that_.copy_(this_) + + +class FP16Optimizer(Optimizer): + """Float16 optimizer for fp16 and bf16 data types. + + Args: + optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD + grad_scaler (BaseGradScaler): grad scaler for gradient chose in + ``constant_grad_scaler`` or ``dynamic_grad_scaler``. + clip_grad_norm (float, optional): clip gradients with this global L2 norm. Default 0. + Note that clipping is ignored if clip_grad == 0 + verbose (bool, optional): if set to `True`, will print debug info. Default False. + """ + + def __init__(self, + optimizer: Optimizer, + grad_scaler: BaseGradScaler, + verbose: bool = False, + clip_grad_norm=0, + dp_process_group: ProcessGroup = None, + mp_process_group: ProcessGroup = None): + # have a defaults for compatibility with pytorch optim + self._optimizer = optimizer + self._defaults = optimizer.defaults + + # fp16-related params + assert isinstance(grad_scaler, BaseGradScaler) + self._grad_scaler = grad_scaler + self._found_overflow = torch.cuda.FloatTensor([0.0]) + self._dummy_overflow_buf = torch.cuda.IntTensor([0]) + + # misc params + self._clip_grad_max_norm = clip_grad_norm + + # get process group + def _get_process_group(parallel_mode): + if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA): + return gpc.get_group(ParallelMode.DATA) + else: + return None + + if dp_process_group is None: + dp_process_group = _get_process_group(ParallelMode.DATA) + if mp_process_group is None: + mp_process_group = _get_process_group(ParallelMode.MODEL) + + self._dp_process_group = dp_process_group + self._mp_process_group = mp_process_group + + # we maintain three groups of parameters + # so that the model can have a mixture + # of fp16 and fp32 params + # fp16_param_groups: the fp16 params of the model + # fp32_master_param_groups: the fp32 params cast from the fp16 param of the model + # fp32_param_groups: the fp32 params of the model + # NOTE: + # 1. fp16_param_groups and fp32_master_param_groups have one-to-one correspondence + # 2. fp32_param_groups and fp16_param_groups are exclusive of each other + self._fp16_param_groups = [] + self._fp32_master_param_groups = [] + self._fp32_param_groups = [] + + # For all the groups in the original optimizer: + for param_group in self._optimizer.param_groups: + fp16_params = [] + fp32_master_params = [] + fp32_params = [] + # For all the parameters in this group: + for i, param in enumerate(param_group['params']): + if param.requires_grad: + # float16 params: + if param.type() in ['torch.cuda.HalfTensor']: + fp16_params.append(param) + + # Create a fp32 copy + fp32_param = param.detach().clone().float() + # Copy tensor model parallel attributes. + copy_tensor_parallel_attributes(param, fp32_param) + + # Replace the optimizer params with the new fp32 copy. + param_group['params'][i] = fp32_param + fp32_master_params.append(fp32_param) + + # Reset existing state dict key to the new main param. + if param in self._optimizer.state: + self._optimizer.state[fp32_param] = self._optimizer.state.pop(param) + + # fp32 params. + elif param.type() == 'torch.cuda.FloatTensor': + fp32_params.append(param) + else: + raise TypeError('Expected parameter of type torch.cuda.FloatTensor ' + f'or torch.cuda.HalfTensor, but got {param.type()}') + + self._fp16_param_groups.append(fp16_params) + self._fp32_master_param_groups.append(fp32_master_params) + self._fp32_param_groups.append(fp32_params) + + # Leverage state_dict() and load_state_dict() to + # recast preexisting per-param state tensors + self._optimizer.load_state_dict(self._optimizer.state_dict()) + + # log config + self._logger = get_dist_logger() + if verbose: + self._logger.info( + f"\n========= FP16 Optimizer Config =========\n" + f"Optimizer: {optimizer.__class__.__name__}\n" + f"clip_grad_norm = {clip_grad_norm}\n" + f"grad_scaler = {self._grad_scaler.__class__.__name__}" + f"==========================================", + ranks=[0]) + + @property + def grad_scaler(self): + """Returns the gradient scaler. + + Returns: + :class:`BaseGradScaler`: gradient scaler. + """ + + return self._grad_scaler + + @property + def loss_scale(self): + """Returns the loss scale. + + Returns: + int: loss scale. + """ + return self._grad_scaler.scale + + @property + def optimizer(self): + """Returns the optimizer. + + Returns: + :class:`torch.optim.Optimizer`: the optimizer object wrapped. + """ + return self._optimizer + + @property + def defaults(self): + """Returns the default arguments of optimizer. + + Returns: + dict: optimizer arguments saved in defaults of the optimizer wrapped. + """ + return self._defaults + + def _check_overflow(self): + # clear previous overflow record + self._found_overflow.fill_(0.0) + + # check for overflow + for group in self._optimizer.param_groups: + for p in group['params']: + if p.grad is not None and has_inf_or_nan(p.grad): + self._found_overflow.fill_(1.0) + break + + # all-reduce across dp group + if self._dp_process_group: + dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_process_group) + + # all-reduce over model parallel group + if self._mp_process_group: + dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_process_group) + + return self._found_overflow.item() > 0 + + def zero_grad(self, set_to_none=True): + """Set gradient to zero. + + Args: + set_to_none (bool): Whether set the gradient to None. + """ + + # set_to_none = True can save some memory space + for param_group in self._optimizer.param_groups: + zero_gard_by_list(param_group['params'], set_to_none=set_to_none) + + def _get_fp32_param_groups_to_update(self): + return self._fp32_master_param_groups + self._fp32_param_groups + + def _unscale_grads(self): + for group in self._get_fp32_param_groups_to_update(): + for p in group: + if p.grad is not None: + p.grad.data.div_(self.loss_scale) + + def _assign_grad_to_fp32_master_param(self): + # This only needs to be done for the float16 group. + for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups, self._fp32_master_param_groups): + for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group): + if fp16_param.grad is not None: + fp32_param.grad = fp16_param.grad.float() + # clear unneeded grad on fp16 param + fp16_param.grad = None + + def _update_fp16_param_from_fp32_param(self): + fp16_param_data = [] + fp32_master_param_data = [] + for fp16_group, fp32_group in zip(self._fp16_param_groups, self._fp32_master_param_groups): + for fp16_param, fp32_param in zip(fp16_group, fp32_group): + fp16_param_data.append(fp16_param.data) + fp32_master_param_data.append(fp32_param.data) + _multi_tensor_copy_this_to_that(this=fp32_master_param_data, + that=fp16_param_data, + overflow_buf=self._dummy_overflow_buf) + + def step(self): + """Update the model parameters. + """ + + # Copy gradients from model params to main params. + self._assign_grad_to_fp32_master_param() + self._unscale_grads() + + overflow = self._check_overflow() + self._grad_scaler.update(overflow) + if overflow: + self.zero_grad() + + # Clip the main gradients. + grad_norm = None + if self._clip_grad_max_norm > 0.0: + grad_norm = self.clip_grad_norm(self._clip_grad_max_norm) + + if not overflow: + # Step the optimizer. + self._optimizer.step() + + # Update params from main params. + self._update_fp16_param_from_fp32_param() + + # Successful update. + return True, grad_norm + else: + return False, None + + def backward(self, loss): + """Execute backward pass. + + Args: + loss (:class:`torch.Tensor`): the loss value. + """ + + scaled_loss = loss * self.grad_scaler.scale + scaled_loss.backward() + + def state_dict(self): + """Returns the states of the fp16 optimizer as a dict object. + """ + + state_dict = {} + state_dict['optimizer'] = self._optimizer.state_dict() + if self.grad_scaler: + state_dict['grad_scaler'] = self.grad_scaler.state_dict() + state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups + return state_dict + + def load_state_dict(self, state_dict): + """Load the states of the fp16 optimizer from a dict object. + + Args: + state_dict (dict): the states of the fp16 optimizer + """ + + # Optimizer. + self._optimizer.load_state_dict(state_dict['optimizer']) + + # Grad scaler. + if 'grad_scaler' in state_dict: + self.grad_scaler.load_state_dict(state_dict['grad_scaler']) + + # Copy data for the main params. + if 'fp32_master_param_groups' in state_dict: + for current_group, ckpt_group in zip(self._fp32_master_param_groups, + state_dict['fp32_master_param_groups']): + for current_param, ckpt_param in zip(current_group, ckpt_group): + current_param.data.copy_(ckpt_param.data) + + def clip_grad_norm(self, clip_grad): + """Clip gradients by norm. + + Args: + clip_grad (float): the max norm for clipping + """ + params = [] + for param_group in self._optimizer.param_groups: + for param in param_group['params']: + params.append(param) + return clip_grad_norm_fp32(params, clip_grad) + + # Promote state so it can be retrieved or set via + # "optimizer_instance.state" + def _get_state(self): + return self._optimizer.state + + def _set_state(self, value): + self._optimizer.state = value + + state = property(_get_state, _set_state) + + # Promote param_groups so it can be retrieved or set via + # "optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + return self._optimizer.param_groups + + def _set_param_groups(self, value): + self._optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups) diff --git a/colossalai/amp/naive_amp/_utils.py b/colossalai/amp/naive_amp/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7633705e19fbce24faec87f9691c834279f0d8ad --- /dev/null +++ b/colossalai/amp/naive_amp/_utils.py @@ -0,0 +1,49 @@ +from typing import List + +from torch import Tensor + + +def has_inf_or_nan(tensor): + """Check if tensor has inf or nan values. + + Args: + tensor (:class:`torch.Tensor`): a torch tensor object + + Returns: + bool: Whether the tensor has inf or nan. True for yes and False for no. + """ + try: + # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as tensor + # (which is true for some recent version of pytorch). + tensor_sum = float(tensor.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # tensor_sum = float(tensor.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum: + return True + return False + + +def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None: + """Clear the gradient of a list of tensors, + + Note: copied from torch.optim.optimizer. + """ + for param in tensor_list: + if param.grad is not None: + if set_to_none: + param.grad = None + else: + if param.grad.grad_fn is not None: + param.grad.detach_() + else: + param.grad.requires_grad_(False) + param.grad.zero_() diff --git a/colossalai/amp/naive_amp/grad_scaler/__init__.py b/colossalai/amp/naive_amp/grad_scaler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dc8499d877e13f8e0eb14317d2cf4a8d54dfcb2a --- /dev/null +++ b/colossalai/amp/naive_amp/grad_scaler/__init__.py @@ -0,0 +1,5 @@ +from .base_grad_scaler import BaseGradScaler +from .constant_grad_scaler import ConstantGradScaler +from .dynamic_grad_scaler import DynamicGradScaler + +__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler'] diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..0d84384a7f67c6a4521a86d34f71ff03b821c7be --- /dev/null +++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod +from typing import Dict + +import torch +from torch import Tensor + +from colossalai.logging import get_dist_logger + +__all__ = ['BaseGradScaler'] + + +class BaseGradScaler(ABC): + """A base class for the gradient scaler. + + Args: + initial_scale (float): the initial loss scale + verbose (bool): whether to log messages + """ + + def __init__(self, initial_scale: float, verbose: bool): + assert initial_scale > 0 + self._scale = torch.cuda.FloatTensor([initial_scale]) + self._verbose = verbose + + if self._verbose: + self._logger = get_dist_logger() + + @property + def scale(self) -> Tensor: + """Returns the loss scale. + """ + + return self._scale + + @property + def inv_scale(self) -> Tensor: + """Returns the inverse of the loss scale. + """ + + return self._scale.double().reciprocal().float() + + def state_dict(self) -> Dict: + """Returns the states of the gradient scaler as a dict object. + """ + + state_dict = dict() + state_dict['scale'] = self.scale + return state_dict + + def load_state_dict(self, state_dict: Dict) -> None: + """Load the states of the gradient scaler from a dict object. + + Args: + state_dict (dict): the states of the gradient scaler + """ + + self._scale = state_dict['scale'] + + @abstractmethod + def update(self, overflow: bool) -> None: + """Update the loss scale. + + Args: + overflow (bool): whether overflow occurs + """ + + pass + + def log(self, message, *args, **kwargs): + """Log messages. + + Args: + message (str): the message to log + *args: positional arguments for :class:`colossalai.logging.DistributedLogger` + **kwargs: key-word arguments for :class:`colossalai.logging.DistributedLogger` + """ + + if self._verbose: + self._logger.info(message, *args, **kwargs) diff --git a/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f518c5dd28261f98faadf9134721ef1fd67dc7 --- /dev/null +++ b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +from .base_grad_scaler import BaseGradScaler + +__all__ = ['ConstantGradScaler'] + + +class ConstantGradScaler(BaseGradScaler): + """A gradient scaler which uses constant loss scale + + Args: + initial_scale (float): the initial loss scale + verbose (bool): whether to log messages + """ + + def __init__(self, initial_scale: int, verbose: bool): + super().__init__(initial_scale, verbose) + self.log(f"Constant Gradient Scaler is initialized with scale {self.scale}", ranks=[0]) + + def update(self, overflow: bool) -> None: + """Do nothing to keep the loss scale constant. + + Args: + overflow (bool): whether overflow occurs + """ + pass diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac26ee914aec979642b672d4ebf6ad1dca7a47c --- /dev/null +++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +from .base_grad_scaler import BaseGradScaler +from typing import Optional + +__all__ = ['DynamicGradScaler'] + + +class DynamicGradScaler(BaseGradScaler): + """A gradient scaler which uses dynamic loss scale + + Args: + initial_scale (float): the initial loss scale, defaults to 2**16 + growth_factor (float): the multiplication factor for increasing loss scale, defaults to 2 + backoff_factor (float): the multiplication factor for decreasing loss scale, defaults to 0.5 + growth_interval (int): the number of steps to increase loss scale when no overflow occurs, defaults to 1000 + min_scale (float): the minimum loss scale, defaults to None + max_scale (float): the maximum loss scale, defaults to None + hysteresis (int): the number of overflows before decreasing loss scale, defaults to 2 + verbose (bool): whether to log messages, defaults to False + """ + + def __init__(self, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + min_scale: Optional[float] = None, + max_scale: Optional[float] = None, + hysteresis: int = 2, + verbose: bool = False): + super().__init__(initial_scale, verbose) + if min_scale: + self._min_scale = torch.cuda.FloatTensor([min_scale]) + else: + self._min_scale = None + + if max_scale: + self._max_scale = torch.cuda.FloatTensor([max_scale]) + else: + self._max_scale = None + + self._growth_factor = growth_factor + self._backoff_factor = backoff_factor + self._growth_interval = growth_interval + self._growth_step = 0 + self._hysteresis = hysteresis + self._hysteresis_step = 0 + self._sanity_checks() + + def _sanity_checks(self) -> None: + """Check if the arguments are correct. + """ + + if self._min_scale: + assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative' + if self._max_scale: + assert self._min_scale > 0, 'The maximum gradient scale cannot be zero or negative' + assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1' + assert self._backoff_factor < 1 and self._backoff_factor > 0, 'The backoff factor must be between 0 and 1' + assert self._hysteresis >= 0, 'The hysteresis cannot be negative' + + def update(self, overflow: bool) -> None: + """Update the loss scale. + + Args: + overflow (bool): whether overflow occurs + """ + if overflow: + self._hysteresis_step += 1 + self._growth_step = 0 + + if self._hysteresis_step >= self._hysteresis: + self._backoff_scale() + self.log(f"Overflow occurs, the loss scale is adjusted to {self.scale.item()}", ranks=[0]) + else: + self._growth_step += 1 + if self._growth_step == self._growth_interval: + self._growth_step = 0 + self._hysteresis_step = 0 + self._grow_scale() + self.log( + f"No overflow for consecutive {self._growth_interval} steps, " + f"the loss scale is adjusted to {self.scale.item()}", + ranks=[0]) + + def _backoff_scale(self) -> None: + """Decrease the loss scale + """ + + self._scale = self._scale * self._backoff_factor + if self._min_scale: + self._scale = torch.max(self._scale, self._min_scale) + + def _grow_scale(self) -> None: + """Increase the loss scale + """ + + self._scale = self._scale * self._growth_factor + if self._max_scale: + self._scale = torch.min(self._scale, self._max_scale) diff --git a/colossalai/amp/naive_amp/naive_amp.py b/colossalai/amp/naive_amp/naive_amp.py new file mode 100644 index 0000000000000000000000000000000000000000..02eae80b9dbfb3f4b5f470872fa4750381a5853c --- /dev/null +++ b/colossalai/amp/naive_amp/naive_amp.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +import torch.nn as nn +import torch.distributed as dist +from torch import Tensor +from typing import Any +from torch.optim import Optimizer +from torch.distributed import ReduceOp +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.nn.optimizer import ColossalaiOptimizer +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from ._fp16_optimizer import FP16Optimizer + + +class NaiveAMPOptimizer(ColossalaiOptimizer): + """A wrapper class for optimizer to cast all parameters to fp16 + + Args: + optim (torch.optim.Optimizer): A normal optimizer like Adam or SGD. + grad_scaler (BaseGradScaler): grad scaler for gradient chose in + ``constant_grad_scaler`` or ``dynamic_grad_scaler``. + clip_grad_norm (float, optional): clip gradients with this global L2 norm. Default 0. + verbose (bool, optional): if set to `True`, will print debug info. Default False. + + Note: + clipping is ignored if ``clip_grad_norm`` equals 0. + """ + + def __init__(self, optim: Optimizer, *args, **kwargs): + optim = FP16Optimizer(optim, *args, **kwargs) + super().__init__(optim) + + def backward(self, loss: Tensor): + self.optim.backward(loss) + + def step(self): + return self.optim.step() + + def clip_grad_norm(self, model: nn.Module, max_norm: float): + pass + + +class NaiveAMPModel(nn.Module): + r"""A wrapper class for model to cast the model into fp16 and + automatically cast the input and output + + Args: + model (torch.nn.Module): torch.nn.Module to be wrapped. + output_to_fp32 (bool, optional): Whether cast output of this module into fp32. (Default: True) + parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this module. + (Default: ``ParallelMode.DATA``) + sync_buffer (bool, optional): whether to synchronize buffer. (Default: True) + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + def __init__(self, + model: nn.Module, + output_to_fp32: bool = True, + parallel_mode: ParallelMode = ParallelMode.DATA, + sync_buffer: bool = True): + super().__init__() + self.model = model.half() + self._output_to_fp32 = output_to_fp32 + self._sync_buf = sync_buffer + + if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: + self._process_group = gpc.get_group(parallel_mode) + self._world_size = gpc.get_world_size(parallel_mode) + else: + self._process_group = None + self._world_size = 1 + self._sync_buf = False + self._first_eval_run = False + + @property + def sync_buffer(self): + return self._sync_buf + + @sync_buffer.setter + def sync_buffer(self, state: bool): + self._sync_buf = state + + def _convert_to_fp16(self, input_: Any): + if isinstance(input_, Tensor) and input_.dtype == torch.float32: + input_ = input_.half() + return input_ + + def _convert_to_fp32(self, input_: Any): + if isinstance(input_, Tensor) and input_.dtype == torch.float16: + input_ = input_.float() + return input_ + + def _reduce_module_buffer(self): + """ + All-reduce the buffers (e.g. running stats of batch normalization) across + data parallel ranks so that all the ranks will produce consistent results + when given the same input + """ + buf_list = [] + + # find valid buffers + for buf in self.model.buffers(): + if buf is not None: + buf_list.append(buf) + + # reduce buffers across data parallel ranks + if buf_list: + coalesced_buf = _flatten_dense_tensors(buf_list) + coalesced_buf.div_(self._world_size) + dist.all_reduce(coalesced_buf, op=ReduceOp.SUM, group=self._process_group) + unflattened_buf_list = _unflatten_dense_tensors(coalesced_buf, buf_list) + for old, new in zip(buf_list, unflattened_buf_list): + old.copy_(new) + + def eval(self): + self.model.eval() + + # we only sync buffer in the first eval iteration + # so that future eval iterations can be done without communication + self._first_eval_run = True + + def forward(self, *args, **kwargs): + # reduce buffers after forward will lead to error + # as we cannot change the variables needed for gradient computation after forward + # so we sync buffer before forward + if (self.training or self._first_eval_run) and self._sync_buf: + with torch.no_grad(): + self._reduce_module_buffer() + + if self._first_eval_run: + self._first_eval_run = False + + if args: + args = [self._convert_to_fp16(arg) for arg in args] + if kwargs: + for k, v in kwargs.items(): + kwargs[k] = self._convert_to_fp16(v) + + out = self.model(*args, **kwargs) + + if self._output_to_fp32: + if isinstance(out, Tensor): + out = self._convert_to_fp32(out) + elif isinstance(out, (tuple, list)): + out = [self._convert_to_fp32(val) for val in out] + elif isinstance(out, dict): + out = {key: self._convert_to_fp32(val) for key, val in out.items()} + return out diff --git a/colossalai/amp/torch_amp/__init__.py b/colossalai/amp/torch_amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..893cc890d68e423c643e6dc4bbf6343ff174a8d7 --- /dev/null +++ b/colossalai/amp/torch_amp/__init__.py @@ -0,0 +1,45 @@ +from typing import Optional + +import torch.nn as nn +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer + +from colossalai.context import Config + +from .torch_amp import TorchAMPLoss, TorchAMPModel, TorchAMPOptimizer + + +def convert_to_torch_amp(model: nn.Module, + optimizer: Optimizer, + criterion: Optional[_Loss] = None, + amp_config: Optional[Config] = None): + """A helper function to wrap training components with Pytorch AMP modules + + Args: + model (:class:`torch.nn.Module`): your model object. + optimizer (:class:`torch.optim.Optimizer`): your optimizer object + criterion (:class:`torch.nn.modules.loss._Loss`, optional): your loss function object + amp_config (:class:`colossalai.context.Config` or dict, optional): configuration for Pytorch AMP. + + The ``amp_config`` should include parameters below: + :: + + init_scale (float, optional, default=2.**16) + growth_factor (float, optional, default=2.0) + backoff_factor (float, optional, default=0.5) + growth_interval (int, optional, default=2000) + enabled (bool, optional, default=True) + + Returns: + A tuple (model, optimizer, criterion) + """ + model = TorchAMPModel(model) + if amp_config is None: + amp_config = dict() + optimizer = TorchAMPOptimizer(optimizer, **amp_config) + if criterion: + criterion = TorchAMPLoss(criterion) + return model, optimizer, criterion + + +__all__ = ['convert_to_torch_amp', 'TorchAMPModel', 'TorchAMPLoss', 'TorchAMPOptimizer'] diff --git a/colossalai/amp/torch_amp/_grad_scaler.py b/colossalai/amp/torch_amp/_grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..7b78998fb8c233f13f34fdf64df95bdfd1601ee6 --- /dev/null +++ b/colossalai/amp/torch_amp/_grad_scaler.py @@ -0,0 +1,571 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py +# to support tensor parallel + +import warnings +from collections import abc, defaultdict +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist +from packaging import version +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + + +class _MultiDeviceReplicator(object): + """ + Lazily serves copies of a tensor to requested devices. Copies are cached per-device. + """ + + def __init__(self, master_tensor: torch.Tensor) -> None: + assert master_tensor.is_cuda or master_tensor.device.type == 'xla' + self.master = master_tensor + self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} + + def get(self, device) -> torch.Tensor: + retval = self._per_device_tensors.get(device, None) + if retval is None: + retval = self.master.to(device=device, non_blocking=True, copy=True) + self._per_device_tensors[device] = retval + return retval + + +# Defines default_factory for GradScaler's _per_optimizer_states defaultdict, +# as well as associated "enum" values. Prefers defining these at top level because +# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory. +# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler +# causes a circular reference, which we'd rather avoid. +class OptState(Enum): + READY = 0 + UNSCALED = 1 + STEPPED = 2 + + +def _refresh_per_optimizer_state(): + return {"stage": OptState.READY, "found_inf_per_device": {}} + + +class GradScaler(object): + _scale: Optional[torch.Tensor] + _grows_tracker: Optional[torch.Tensor] + _per_optimizer_states: Dict[int, Dict[str, Any]] + """ + An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling + conveniently. + + * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. + * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. + * ``scaler.update()`` updates ``scaler``'s scale factor. + + Example: + + # Creates a GradScaler once at the beginning of training. + scaler = GradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See the :ref:`Automatic Mixed Precision examples` for usage + (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, + and multiple losses/optimizers. + + ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, + a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if + the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used + without incurring inf or NaN gradient values. + ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every + ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). + + * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params + themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. + + * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. + If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by + ``growth_factor``. + + The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its + value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these + iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). + + Args: + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + """ + + def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True): + if enabled and not torch.cuda.is_available(): + warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.") + self._enabled = False + else: + self._enabled = enabled + + # check version + torch_version = version.parse(torch.__version__) + assert torch_version.major == 1 + if torch_version.minor > 8: + self._higher_than_torch18 = True + else: + self._higher_than_torch18 = False + + if self._enabled: + assert growth_factor > 1.0, "The growth factor must be > 1.0." + assert backoff_factor < 1.0, "The backoff factor must be < 1.0." + + self._init_scale = init_scale + # self._scale will be lazily initialized during the first call to scale() + self._scale = None + self._growth_factor = growth_factor + self._backoff_factor = backoff_factor + self._growth_interval = growth_interval + self._init_growth_tracker = 0 + # self._growth_tracker will be lazily initialized during the first call to scale() + self._growth_tracker = None + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]: + fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." + assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix + assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix + return (self._scale, self._growth_tracker) + + def _lazy_init_scale_growth_tracker(self, dev): + assert self._growth_tracker is None, "_growth_tracker initialized before _scale" + self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=dev) + self._growth_tracker = torch.full((1,), self._init_growth_tracker, dtype=torch.int32, device=dev) + + def scale(self, outputs): + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Args: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + if not self._enabled: + return outputs + + # Short-circuit for the common case. + if isinstance(outputs, torch.Tensor): + assert outputs.is_cuda or outputs.device.type == 'xla' + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None + return outputs * self._scale.to(device=outputs.device, non_blocking=True) + + # Invoke the more complex machinery only if we're treating multiple outputs. + # holds a reference that can be overwritten by apply_scale + stash: List[_MultiDeviceReplicator] = [] + + def apply_scale(val): + if isinstance(val, torch.Tensor): + assert val.is_cuda or val.device.type == 'xla' + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_MultiDeviceReplicator(self._scale)) + return val * stash[0].get(val.device) + elif isinstance(val, abc.Iterable): + iterable = map(apply_scale, val) + if isinstance(val, list) or isinstance(val, tuple): + return type(val)(iterable) + else: + return iterable + else: + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) + + def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): + per_device_inv_scale = _MultiDeviceReplicator(inv_scale) + per_device_found_inf = _MultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + param.grad = param.grad.coalesce() + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + # TODO: is there a way to split by device and dtype without appending in the inner loop? + per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale) + + for device, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device), + per_device_inv_scale.get(device)) + # For tensor parallel paramters it should be all-reduced over tensor parallel process group + if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: + vals = [val for val in per_device_found_inf._per_device_tensors.values()] + coalesced = _flatten_dense_tensors(vals) + dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL)) + for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)): + buf.copy_(synced) + return per_device_found_inf._per_device_tensors + + def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + + ... + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + + Args: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + + .. note:: + :meth:`unscale_` does not incur a CPU-GPU sync. + + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: + raise RuntimeError("unscale_() has already been called on this optimizer since the last update().") + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = self._scale.double().reciprocal().float() + found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False) + optimizer_state["stage"] = OptState.UNSCALED + + def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): + retval = None + if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()): + retval = optimizer.step(*args, **kwargs) + return retval + + def step(self, optimizer, *args, **kwargs): + """ + :meth:`step` carries out the following two operations: + + 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` + earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. + 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled + gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. + + ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. + + Returns the return value of ``optimizer.step(*args, **kwargs)``. + + Args: + optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. + args: Any arguments. + kwargs: Any keyword arguments. + + .. warning:: + Closure use is not currently supported. + """ + if (not self._enabled): + return optimizer.step(*args, **kwargs) + + if "closure" in kwargs: + raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.") + + self._check_scale_growth_tracker("step") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("step() has already been called since the last update().") + + retval = None + + if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): + # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. + # The contract with custom optimizers is that their step() should accept an additional, + # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: + # it can query its own state, invoke unscale_ on itself, etc + retval = optimizer.step(*args, **dict(kwargs, grad_scaler=self)) + optimizer_state["stage"] = OptState.STEPPED + return retval + + if optimizer_state["stage"] is OptState.READY: + self.unscale_(optimizer) + + assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer." + + retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs) + + optimizer_state["stage"] = OptState.STEPPED + + return retval + + def update(self, new_scale=None): + """ + Updates the scale factor. + + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + + Args: + new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. + + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." + # type: ignore[attr-defined] + assert isinstance(new_scale, torch.cuda.FloatTensor), reason + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device=_scale.device, non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + if self._higher_than_torch18: + torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor, + self._backoff_factor, self._growth_interval) + else: + self._scale = torch._amp_update_scale(_growth_tracker, _scale, found_inf_combined, self._growth_factor, + self._backoff_factor, self._growth_interval) + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _get_scale_async(self): + return self._scale + + def get_scale(self): + """ + Returns a Python float containing the current scale, or 1.0 if scaling is disabled. + + .. warning:: + :meth:`get_scale` incurs a CPU-GPU sync. + """ + if self._enabled: + return self._init_scale if self._scale is None else self._get_scale_async().item() + else: + return 1.0 + + def get_growth_factor(self): + r""" + Returns a Python float containing the scale growth factor. + """ + return self._growth_factor + + def set_growth_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale growth factor. + """ + self._growth_factor = new_factor + + def get_backoff_factor(self): + r""" + Returns a Python float containing the scale backoff factor. + """ + return self._backoff_factor + + def set_backoff_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale backoff factor. + """ + self._backoff_factor = new_factor + + def get_growth_interval(self): + r""" + Returns a Python int containing the growth interval. + """ + return self._growth_interval + + def set_growth_interval(self, new_interval): + r""" + Args: + new_interval (int): Value to use as the new growth interval. + """ + self._growth_interval = new_interval + + def _get_growth_tracker(self): + if self._enabled: + return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item() + else: + return 0 + + def is_enabled(self): + r""" + Returns a bool indicating whether this instance is enabled. + """ + return self._enabled + + def state_dict(self): + r""" + Returns the state of the scaler as a :class:`dict`. It contains five entries: + + * ``"scale"`` - a Python float containing the current scale + * ``"growth_factor"`` - a Python float containing the current growth factor + * ``"backoff_factor"`` - a Python float containing the current backoff factor + * ``"growth_interval"`` - a Python int containing the current growth interval + * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. + + If this instance is not enabled, returns an empty dict. + + .. note:: + If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` + should be called after :meth:`update`. + """ + return { + "scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker() + } if self._enabled else {} + + def load_state_dict(self, state_dict): + r""" + Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. + + Args: + state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. + """ + if not self._enabled: + return + + if len(state_dict) == 0: + raise RuntimeError("The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler.") + + self._init_scale = state_dict["scale"] + if self._scale is not None: + self._scale.fill_(state_dict["scale"]) + self._growth_factor = state_dict["growth_factor"] + self._backoff_factor = state_dict["backoff_factor"] + self._growth_interval = state_dict["growth_interval"] + self._init_growth_tracker = state_dict["_growth_tracker"] + if self._growth_tracker is not None: + self._growth_tracker.fill_(state_dict["_growth_tracker"]) + + def __getstate__(self): + state = self.__dict__.copy() + if self._enabled: + assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ + "of an iteration, or at the end after scaler.update()." + # Pickling _scale and _growth_tracker Tensors directly triggers + # "warnings.warn("pickle support for Storage will be removed in 1.5..." + # so instead, we set the unpickled instance up to reinitialize them lazily. + state['_init_scale'] = self.get_scale() + state['_init_growth_tracker'] = self._get_growth_tracker() + state['_scale'] = None + state['_growth_tracker'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + def _check_inf_per_device(self, optimizer): + _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") + + dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device) + found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device) + + self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ + self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] + + def _found_inf_per_device(self, optimizer): + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/colossalai/amp/torch_amp/torch_amp.py b/colossalai/amp/torch_amp/torch_amp.py new file mode 100644 index 0000000000000000000000000000000000000000..5074e9c81d35ea71ab5fcca38d49419d78c8694f --- /dev/null +++ b/colossalai/amp/torch_amp/torch_amp.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.nn as nn +import torch.cuda.amp as torch_amp + +from torch import Tensor +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer +from ._grad_scaler import GradScaler + +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.utils import clip_grad_norm_fp32 + + +class TorchAMPOptimizer(ColossalaiOptimizer): + """A wrapper class which integrate Pytorch AMP with an optimizer + + Args: + optim (torch.optim.Optimizer): A normal optimizer like Adam or SGD. + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + """ + + def __init__(self, optim: Optimizer, *args, **kwargs): + super().__init__(optim) + self.scaler = GradScaler(*args, **kwargs) + + def backward(self, loss: Tensor): + """Backward with torch amp gradient scaler + + Args: + loss (torch.Tensor): Loss computed by a loss function + """ + self.scaler.scale(loss).backward() + + def step(self): + """Update the parameters of the model + """ + self.scaler.step(self.optim) + self.scaler.update() + + def clip_grad_norm(self, model: nn.Module, max_norm: float): + """Apply gradient clipping to the model parameters + + Args: + model (torch.nn.Module): Your model object + max_norm (float): Max norm value for gradient clipping + """ + if max_norm > 0.0: + self.scaler.unscale_(self.optim) + clip_grad_norm_fp32(model.parameters(), max_norm) + + +class TorchAMPModel(nn.Module): + """A wrapper class for a model object which executes forward with values automatically + cast to fp16 + + Args: + model (:class:`torch.nn.Module`): a torch model instance + """ + + def __init__(self, model: nn.Module) -> None: + super().__init__() + self.model = model + + @torch_amp.autocast() + def forward(self, *args, **kwargs): + """ + Execute forward under the torch amp context + """ + return self.model(*args, **kwargs) + + +class TorchAMPLoss(nn.Module): + """A wrapper class for a criterion object which computes the loss in mixed-precision context + + Args: + loss (torch.nn.modules.loss._Loss): A loss function object + """ + + def __init__(self, loss: _Loss): + super().__init__() + self.loss = loss + + @torch_amp.autocast() + def forward(self, *args, **kwargs): + """ + Execute forward under the torch amp context + """ + return self.loss(*args, **kwargs) diff --git a/colossalai/auto_parallel/__init__.py b/colossalai/auto_parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colossalai/auto_parallel/checkpoint/__init__.py b/colossalai/auto_parallel/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10ade417a238753af49c3780cd6693b362b7bbb4 --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/__init__.py @@ -0,0 +1,3 @@ +from .ckpt_solver_base import CheckpointSolverBase +from .ckpt_solver_chen import CheckpointSolverChen +from .ckpt_solver_rotor import CheckpointSolverRotor diff --git a/colossalai/auto_parallel/checkpoint/build_c_ext.py b/colossalai/auto_parallel/checkpoint/build_c_ext.py new file mode 100644 index 0000000000000000000000000000000000000000..af4349865a7b8dd748e34458eb3d5aeeb359b599 --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/build_c_ext.py @@ -0,0 +1,16 @@ +import os + +from setuptools import Extension, setup + +this_dir = os.path.dirname(os.path.abspath(__file__)) +ext_modules = [Extension( + 'rotorc', + sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')], +)] + +setup( + name='rotor c extension', + version='0.1', + description='rotor c extension for faster dp computing', + ext_modules=ext_modules, +) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py new file mode 100644 index 0000000000000000000000000000000000000000..63eff31b2da791711872a00a98200710203130e2 --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py @@ -0,0 +1,171 @@ +from abc import ABC, abstractmethod +from copy import deepcopy +from typing import Any, List + +import torch +from torch.fx import Graph, Node + +from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen +from colossalai.fx.profiler.memory_utils import is_inplace + +__all___ = ['CheckpointSolverBase'] + + +def _copy_output(src: Graph, dst: Graph): + """Copy the output node from src to dst""" + for n_src, n_dst in zip(src.nodes, dst.nodes): + if n_src.op == 'output': + n_dst.meta = n_src.meta + + +def _get_param_size(module: torch.nn.Module): + """Get the size of the parameters in the module""" + return sum([p.numel() * torch.tensor([], dtype=p.dtype).element_size() for p in module.parameters()]) + + +class CheckpointSolverBase(ABC): + + def __init__( + self, + graph: Graph, + free_memory: float = -1.0, + requires_linearize: bool = False, + cnode: List[str] = None, + ): + """CheckpointSolver class will integrate information provided by the components + and use an existing solver to find a possible optimal strategies combination for + target computing graph. + + Existing Solvers: + Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen) + Rotor solver: https://hal.inria.fr/hal-02352969 (CheckpointSolverRotor) + + Args: + graph (Graph): The computing graph to be optimized. + free_memory (float): Memory constraint for the solution. + requires_linearize (bool): Whether the graph needs to be linearized. + cnode (List[str], optional): Common node List, should be the subset of input. Default to None. + + Warnings: + `MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required. + """ + # super-dainiu: this graph is a temporary graph which can refer to + # the owning module, but we will return another deepcopy of it after + # the solver is executed. + self.graph = deepcopy(graph) + self.graph.owning_module = graph.owning_module + _copy_output(graph, self.graph) + self.graph.set_codegen(ActivationCheckpointCodeGen()) + + # check if `MetaInfoProp` is done + if any(len(node.meta) == 0 for node in self.graph.nodes): + raise RuntimeError( + "Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!") + + self.free_memory = free_memory + self.parameter_size = _get_param_size(self.graph.owning_module) + self.cnode = cnode + self.requires_linearize = requires_linearize + if self.requires_linearize: + self.node_list = self._linearize_graph() + else: + self.node_list = self.get_node_list() + + @abstractmethod + def solve(self): + """Solve the checkpointing problem and return the solution. + """ + pass + + def get_node_list(self): + """Get the node list. + """ + return [[node] for node in self.graph.nodes] + + def _linearize_graph(self) -> List[List[Node]]: + """Linearizing the graph + + Args: + graph (Graph): The computing graph to be optimized. + + Returns: + List[List[Node]]: List of list, each inside list of Node presents + the actual 'node' in linearized manner. + + Remarks: + Do merge the inplace ops into the previous node. + """ + + # Common nodes are type of nodes that could be seen as attributes and remain + # unchanged throughout the whole model, it will be used several times by + # different blocks of model, so that it is hard for us to linearize the graph + # when we encounter those kinds of nodes. We let users to annotate some of the + # input as common node, such as attention mask, and the followings are some of + # the ops that could actually be seen as common nodes. With our common node prop, + # we could find some of the "real" common nodes (e.g. the real attention mask + # used in BERT and GPT), the rule is simple, for node who's parents are all common + # nodes or it's op belongs to the following operations, we view this node as a + # newly born common node. + # List of target name that could be seen as common node + common_ops = ["getattr", "getitem", "size"] + + def _is_cop(target: Any) -> bool: + """Check if an op could be seen as common node + + Args: + target (Any): node target + + Returns: + bool + """ + + if isinstance(target, str): + return target in common_ops + else: + return target.__name__ in common_ops + + def _is_sink() -> bool: + """Check if we can free all dependencies + + Returns: + bool + """ + + return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users)) + + # make sure that item in cnode is valid + if self.cnode: + for name in self.cnode: + try: + assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \ + f"Common node {name} is not an input of the model." + except StopIteration: + raise ValueError(f"Common node name {name} not in graph.") + + else: + self.cnode = [] + + deps = {} + node_list = [] + region = [] + + for n in self.graph.nodes: + if n.op != "placeholder" and n.op != "output": + for n_par in n.all_input_nodes: + if n_par.op != "placeholder" and n_par.name not in self.cnode: + deps[n_par] -= 1 + region.append(n) + + # if the node could free all dependencies in graph + # we could begin a new node + if _is_sink(): + node_list.append(region) + region = [] + + # propagate common node attr if possible + if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode + ]) or _is_cop(n.target): + self.cnode.append(n.name) + else: + deps[n] = len([user for user in n.users if user.op != "output"]) + return node_list diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py new file mode 100644 index 0000000000000000000000000000000000000000..58878253e99e1bbbf045e01b4f4fd72bf1112bf1 --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py @@ -0,0 +1,87 @@ +import math +from copy import deepcopy +from typing import List, Set, Tuple + +from torch.fx import Graph, Node + +from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp + +from .ckpt_solver_base import CheckpointSolverBase + +__all__ = ['CheckpointSolverChen'] + + +class CheckpointSolverChen(CheckpointSolverBase): + + def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6): + """ + This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. + Note that this algorithm targets at memory optimization only, using techniques in appendix A. + + Usage: + Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp` + to the graph to retrieve all information needed, then we could use the following + code to find a solution using `CheckpointSolverChen`: + >>> solver = CheckpointSolverChen(gm.graph) + >>> chen_graph = solver.solve() + >>> gm.graph = chen_graph # set the graph to a new graph + + Args: + graph (Graph): The computing graph to be optimized. + cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None. + num_grids (int, optional): Number of grids to search for b. Defaults to 6. + """ + super().__init__(graph, 0, 0, True, cnode) + self.num_grids = num_grids + + def solve(self) -> Graph: + """Solve the checkpointing problem using Algorithm 3. + + Returns: + graph (Graph): The optimized graph, should be a copy of the original graph. + """ + checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr'] + ckpt = self.grid_search() + for i, seg in enumerate(ckpt): + for idx in range(*seg): + nodes = self.node_list[idx] + for n in nodes: + if n.op in checkpointable_op: + n.meta['activation_checkpoint'] = i + return deepcopy(self.graph) + + def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]: + """ + This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. + """ + ckpt_intv = [] + temp = 0 + x = 0 + y = 0 + prev_idx = 2 + for idx, nodes in enumerate(self.node_list): + for n in nodes: + n: Node + temp += calculate_fwd_in(n) + calculate_fwd_tmp(n) + y = max(y, temp) + if temp > b and idx > prev_idx: + x += calculate_fwd_in(nodes[0]) + temp = 0 + ckpt_intv.append((prev_idx, idx + 1)) + prev_idx = idx + 1 + return ckpt_intv, math.floor(math.sqrt(x * y)) + + def grid_search(self) -> Set: + """ + Search ckpt strategy with b = 0, then run the allocation algorithm again with b = โˆšxy. + Grid search over [โˆš2/2 b, โˆš2 b] for ckpt_opt over num_grids as in appendix A. + """ + _, b_approx = self.run_chen_greedy(0) + b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2)) + b_opt = math.inf + for b in range(b_min, b_max, (b_max - b_min) // self.num_grids): + ckpt_intv, b_approx = self.run_chen_greedy(b) + if b_approx < b_opt: + b_opt = b_approx + ckpt_opt = ckpt_intv + return ckpt_opt diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c new file mode 100644 index 0000000000000000000000000000000000000000..0fdcfd58a399f73f091a963af027e637b9761d7b --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c @@ -0,0 +1,197 @@ +#define PY_SSIZE_T_CLEAN +#include + +long* PySequenceToLongArray(PyObject* pylist) { + if (!(pylist && PySequence_Check(pylist))) return NULL; + Py_ssize_t len = PySequence_Size(pylist); + long* result = (long*)calloc(len + 1, sizeof(long)); + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject* item = PySequence_GetItem(pylist, i); + result[i] = PyLong_AsLong(item); + Py_DECREF(item); + } + result[len] = 0; + return result; +} + +double* PySequenceToDoubleArray(PyObject* pylist) { + if (!(pylist && PySequence_Check(pylist))) return NULL; + Py_ssize_t len = PySequence_Size(pylist); + double* result = (double*)calloc(len + 1, sizeof(double)); + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject* item = PySequence_GetItem(pylist, i); + result[i] = PyFloat_AsDouble(item); + Py_DECREF(item); + } + result[len] = 0; + return result; +} + +long* getLongArray(PyObject* container, const char* attributeName) { + PyObject* sequence = PyObject_GetAttrString(container, attributeName); + long* result = PySequenceToLongArray(sequence); + Py_DECREF(sequence); + return result; +} + +double* getDoubleArray(PyObject* container, const char* attributeName) { + PyObject* sequence = PyObject_GetAttrString(container, attributeName); + double* result = PySequenceToDoubleArray(sequence); + Py_DECREF(sequence); + return result; +} + +static PyObject* computeTable(PyObject* self, PyObject* args) { + PyObject* chainParam; + int mmax; + + if (!PyArg_ParseTuple(args, "Oi", &chainParam, &mmax)) return NULL; + + double* ftime = getDoubleArray(chainParam, "ftime"); + if (!ftime) return NULL; + + double* btime = getDoubleArray(chainParam, "btime"); + if (!btime) return NULL; + + long* x = getLongArray(chainParam, "x"); + if (!x) return NULL; + + long* xbar = getLongArray(chainParam, "xbar"); + if (!xbar) return NULL; + + long* ftmp = getLongArray(chainParam, "btmp"); + if (!ftmp) return NULL; + + long* btmp = getLongArray(chainParam, "btmp"); + if (!btmp) return NULL; + + long chainLength = PyObject_Length(chainParam); + if (!chainLength) return NULL; + +#define COST_TABLE(m, i, l) \ + costTable[(m) * (chainLength + 1) * (chainLength + 1) + \ + (i) * (chainLength + 1) + (l)] + double* costTable = (double*)calloc( + (mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(double)); + +#define BACK_PTR(m, i, l) \ + backPtr[(m) * (chainLength + 1) * (chainLength + 1) + \ + (i) * (chainLength + 1) + (l)] + long* backPtr = (long*)calloc( + (mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(long)); + + for (long m = 0; m <= mmax; ++m) + for (long i = 0; i <= chainLength; ++i) + if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) && + (m >= x[i + 1] + xbar[i + 1] + ftmp[i])) + COST_TABLE(m, i, i) = ftime[i] + btime[i]; + else + COST_TABLE(m, i, i) = INFINITY; + + for (long m = 0; m <= mmax; ++m) + for (long d = 1; d <= chainLength; ++d) { + for (long i = 0; i <= chainLength - d; ++i) { + long idx = i + d; + long mmin = x[idx + 1] + x[i + 1] + ftmp[i]; + if (idx > i + 1) { + long maxCostFWD = 0; + for (long j = i + 1; j < idx; j++) { + maxCostFWD = fmaxl(maxCostFWD, x[j] + x[j + 1] + ftmp[j]); + } + mmin = fmaxl(mmin, x[idx + 1] + maxCostFWD); + } + if ((m >= mmin)) { + long bestLeaf = -1; + double sumFw = 0; + double bestLeafCost = INFINITY; + for (long j = i + 1; j <= idx; ++j) { + sumFw += ftime[j - 1]; + if (m >= x[j]) { + double cost = sumFw + COST_TABLE(m - x[j], j, idx) + + COST_TABLE(m, i, j - 1); + if (cost < bestLeafCost) { + bestLeafCost = cost; + bestLeaf = j; + } + } + } + double chainCost = INFINITY; + if (m >= xbar[i + 1]) + chainCost = + COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx); + if (bestLeafCost <= chainCost) { + COST_TABLE(m, i, idx) = bestLeafCost; + BACK_PTR(m, i, idx) = bestLeaf; + } else { + COST_TABLE(m, i, idx) = chainCost; + BACK_PTR(m, i, idx) = -1; + } + } else + COST_TABLE(m, i, idx) = INFINITY; + } + } + + free(ftime); + free(btime); + free(x); + free(xbar); + free(ftmp); + free(btmp); + + PyObject* pyCostTable = PyList_New(mmax + 1); + PyObject* pyBackPtr = PyList_New(mmax + 1); + + // Convert the result into Python world + for (long m = 0; m <= mmax; ++m) { + PyObject* pyCostTable_m = PyList_New(chainLength + 1); + PyList_SET_ITEM(pyCostTable, m, pyCostTable_m); + PyObject* pyBackPtr_m = PyList_New(chainLength + 1); + PyList_SET_ITEM(pyBackPtr, m, pyBackPtr_m); + for (long i = 0; i <= chainLength; ++i) { + PyObject* pyCostTable_m_i = PyDict_New(); + PyList_SET_ITEM(pyCostTable_m, i, pyCostTable_m_i); + PyObject* pyBackPtr_m_i = PyDict_New(); + PyList_SET_ITEM(pyBackPtr_m, i, pyBackPtr_m_i); + for (long l = i; l <= chainLength; ++l) { + PyObject* pyVar_l = PyLong_FromLong(l); + PyObject* pyCostTable_m_i_l = PyFloat_FromDouble(COST_TABLE(m, i, l)); + PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l); + Py_DECREF(pyCostTable_m_i_l); + PyObject* pyBackPtr_m_i_l; + if (BACK_PTR(m, i, l) < 0) + pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True); + else + pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l)); + PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l); + Py_DECREF(pyBackPtr_m_i_l); + Py_DECREF(pyVar_l); + } + } + } + + free(costTable); + free(backPtr); + + PyObject* result = PyTuple_Pack(2, pyCostTable, pyBackPtr); + Py_DECREF(pyCostTable); + Py_DECREF(pyBackPtr); + return result; +} + +static PyMethodDef rotorMethods[] = { + {"compute_table", computeTable, METH_VARARGS, + "Compute the optimal table with the rotor algorithm."}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +static struct PyModuleDef rotorModule = { + PyModuleDef_HEAD_INIT, "rotorc", /* name of module */ + "A simple implementation of dynamic programming algorithm rotor with C in " + "https://hal.inria.fr/hal-02352969. Some code are adapted from " + "https://gitlab.inria.fr/hiepacs/rotor.", /* module documentation, may be + NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + rotorMethods}; + +PyMODINIT_FUNC PyInit_rotorc(void) { return PyModule_Create(&rotorModule); } diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py new file mode 100644 index 0000000000000000000000000000000000000000..72bc67e02e888dc3ab18f9fc8e7db510e65d41bd --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py @@ -0,0 +1,426 @@ +from copy import deepcopy +from typing import Any, Dict, List, Tuple + +from torch import Tensor +from torch.fx import Graph, Node + +from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions +from colossalai.fx.profiler import ( + activation_size, + calculate_bwd_time, + calculate_fwd_out, + calculate_fwd_time, + calculate_fwd_tmp, +) +from colossalai.logging import get_dist_logger + +from .ckpt_solver_base import CheckpointSolverBase +from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence + +__all__ = ['CheckpointSolverRotor'] + + +class CheckpointSolverRotor(CheckpointSolverBase): + + def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500): + """This is the simple implementation of dynamic programming algorithm rotor + in https://hal.inria.fr/hal-02352969. Some code are adapted from + https://gitlab.inria.fr/hiepacs/rotor. + + Usage: + Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp` + to the graph to retrieve all information needed, then we could use the following + code to find a solution using `CheckpointSolverRotor`: + >>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0]) + >>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver + >>> gm.graph = rotor_graph # set the graph to a new graph + + Args: + graph (Graph): The computing graph to be optimized. + free_memory (float, optional): Memory constraint for the solution, unit is byte. + Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1. + cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None. + memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500. + """ + super().__init__(graph, free_memory, True, cnode) + self.memory_slots = memory_slots + + # construct chain + unit = self.free_memory // self.memory_slots + self.chain = self._construct_chain(self.graph, self.node_list) + self.chain.discretize_all(unit) + + self.cost_table = None + self.back_ptr = None + self.sequence = None + + def solve(self, force_python: bool = False, verbose: bool = False) -> Graph: + """Solve the checkpointing problem using rotor algorithm. + + Args: + force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False. + verbose (bool, optional): Print verbose information. Defaults to False. + + Returns: + graph (Graph): The optimized graph, should be a copy of the original graph. + """ + chain = self.chain + + # compute cost table + if force_python: + self.cost_table, self.back_ptr = self._compute_table(chain, self.memory_slots) + else: + self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots) + + if verbose: + self.print_chain() + + # backtrack + try: + self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, + self.back_ptr) + self._annotate_from_sequence(self.sequence, self.node_list) + except ValueError as e: + # using logger to annonce that the solver is failed + logger = get_dist_logger() + logger.warning(f'Checkpoint solver failed: {e}') + raise ValueError + + if verbose: + self.print_sequence() + + return deepcopy(self.graph) + + def print_chain(self): + print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0]) + for idx in range(len(self.node_list) - 1): + print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx], + self.chain.btmp[idx]) + print(f'Chain = {self.chain}') + + def print_sequence(self): + print(f'Sequence = {self.sequence}') + + @classmethod + def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain: + input_tensors = cls._extract_input(graph) + ftime, btime, ftmp, btmp = list(), list(), list(), list() + xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)] + + for node in node_list: + node_info = cls._extract_node_info(node) + ftime.append(node_info[0]) + btime.append(node_info[1]) + x.append(node_info[2]) + xbar.append(node_info[3]) + ftmp.append(node_info[4]) + btmp.append(node_info[5]) + + # currently we view loss backward temp as zero + btime.append(0) + btmp.append(0) + + return Chain(ftime, btime, x, xbar, ftmp, btmp) + + @classmethod + def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]: + """Extract node info from a list of nodes""" + xbar = 0 + ftime = 0 + btime = 0 + for n in node: + assert isinstance(n, Node), f'{n} is not a Node' + xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) + # minimum flop count is required + ftime += max(calculate_fwd_time(n), 1.0) + btime += max(calculate_bwd_time(n), 1.0) + + x = calculate_fwd_out(node[-1]) + xbar = max(x, xbar) + ftmp = cls._extract_ftmp(node) + btmp = cls._extract_btmp(node) + return ftime, btime, x, xbar, ftmp, btmp + + @staticmethod + def _extract_input(graph: Graph) -> Tuple[Tensor, ...]: + """Extract input tensors from a Graph""" + input_tensors = [] + for node in graph.nodes: + if node.op == 'placeholder': + input_tensors.append(node.meta['fwd_out']) + return input_tensors + + @staticmethod + def _extract_ftmp(node: List[Node]) -> int: + """Extract ftmp from a list of nodes""" + n = node[-1] + return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n) + + @staticmethod + def _extract_btmp(node: List[Node]) -> int: + """Extract btmp from a list of nodes""" + + def _extract_deps_size(): + deps_size = 0 + for k, v in deps.items(): + k: Node + if v > 0: + deps_size += k.meta['bwd_mem_out'] + if v == float('-inf'): + deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k) + + return deps_size + + btmp = 0 + deps = {} + for n in reversed(node): + deps[n] = len(n.all_input_nodes) + btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp']) + for child in n.users: + if child in deps: + deps[child] -= 1 + if deps[child] <= 0: + deps[child] = float('-inf') # free + return btmp + + @staticmethod + def _compute_table(chain: Chain, mmax: int) -> Tuple: + """Compute the table using dynamic programming. Returns the cost table and the backtracking pointer. + + Args: + chain (Chain): A basic linearized structure for solving the dynamic programming problem. + mmax (int): Maximum number of memory slots. + + Returns: + cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length + and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax + back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice + is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint + of length j + """ + + ftime = chain.ftime + [0.0] + btime = chain.btime + x = chain.x + [0] + xbar = chain.xbar + [0] + ftmp = chain.ftmp + [0] + btmp = chain.btmp + [0] + + # Build table + cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)] + back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)] + # Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation + + # Initialize borders of the tables for lmax-lmin = 0 + for m in range(mmax + 1): + for i in range(len(chain) + 1): + limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i]) + if m >= limit: # Equation (1) + cost_table[m][i][i] = ftime[i] + btime[i] + else: + cost_table[m][i][i] = float("inf") + + # Compute everything + for m in range(mmax + 1): + for d in range(1, len(chain) + 1): + for i in range(len(chain) + 1 - d): + idx = i + d + mmin = x[idx + 1] + x[i + 1] + ftmp[i] + if idx > i + 1: + mmin = max(mmin, x[idx + 1] + max(x[j] + x[j + 1] + ftmp[j] for j in range(i + 1, idx))) + if m < mmin: + cost_table[m][i][idx] = float("inf") + else: + leaf_checkpoints = [(j, + sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1]) + for j in range(i + 1, idx + 1) + if m >= x[j]] + if leaf_checkpoints: + best_leaf = min(leaf_checkpoints, key=lambda t: t[1]) + else: + best_leaf = None + if m >= xbar[i + 1]: + chain_checkpoint = cost_table[m][i][i] + cost_table[m - xbar[i + 1]][i + 1][idx] + else: + chain_checkpoint = float("inf") + if best_leaf and best_leaf[1] <= chain_checkpoint: + cost_table[m][i][idx] = best_leaf[1] + back_ptr[m][i][idx] = (False, best_leaf[0]) + else: + cost_table[m][i][idx] = chain_checkpoint + back_ptr[m][i][idx] = (True,) + return cost_table, back_ptr + + @staticmethod + def _compute_table_c(chain: Chain, mmax: int) -> Tuple: + try: + from .rotorc import compute_table + + # build module if module not found + except ModuleNotFoundError: + import os + import subprocess + import sys + logger = get_dist_logger() + logger.info("rotorc hasn't been built! Building library...", ranks=[0]) + this_dir = os.path.dirname(os.path.abspath(__file__)) + result = subprocess.Popen( + [ + f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext", + f"--build-lib={this_dir}" + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if result.wait() == 0: + logger.info("rotorc has been built!", ranks=[0]) + from .rotorc import compute_table + else: + logger.warning("rotorc built failed! Using python version!", ranks=[0]) + return CheckpointSolverRotor._compute_table(chain, mmax) + return compute_table(chain, mmax) + + @staticmethod + def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], + back_ptr: List[Any]) -> "Sequence": + """Backtrack the cost table and retrieve the optimal checkpointing strategy. + + Args: + chain (Chain): A basic linearized structure for solving the dynamic programming problem. + lhs (int): The left index of the interval to backtrack. + rhs (int): The right index of the interval to backtrack. + budget (int): The memory budget for processing this interval. + cost_table (List[Any]): See `._compute_table()` for definitions + back_ptr (List[Any]): See `._compute_table()` for definitions + + Raises: + ValueError: Can not process the chain. + + Returns: + sequence (Sequence): The sequence of executing nodes with checkpoints. + """ + if budget <= 0: + raise ValueError(f"Can not process a chain with negative memory {budget}") + elif cost_table[budget][lhs][rhs] == float("inf"): + raise ValueError(f"Can not process this chain from index {lhs} to {rhs} with memory {budget}") + + sequence = Sequence() + if rhs == lhs: + if lhs == len(chain): + sequence += [Loss()] + else: + sequence += [ForwardEnable(lhs), Backward(lhs)] + return sequence + + if back_ptr[budget][lhs][rhs][0]: + sequence += [ + ForwardEnable(lhs), + CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, + back_ptr), + Backward(lhs), + ] + else: + best_leaf = back_ptr[budget][lhs][rhs][1] + sequence += [ForwardCheck(lhs)] + sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)] + sequence += [ + CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, + back_ptr), + CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr), + ] + return sequence + + @staticmethod + def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): + """Annotate the nodes in the node_list with activation checkpoint from the sequence. + + Args: + sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations. + node_list (List[List[Node]]): The list of nodes to annotate. + """ + op_list = sequence.list_operations() + loss_op = next(op for op in op_list if isinstance(op, Loss)) + fwd_list = op_list[:op_list.index(loss_op)] + bwd_list = op_list[op_list.index(loss_op) + 1:] + ckpt_idx = 0 + in_ckpt = False + ckpt_region = [] + + # forward annotation + for idx, op in enumerate(fwd_list, 0): + if in_ckpt: + if isinstance(op, ForwardNograd): + ckpt_region.append(idx) + + elif isinstance(op, ForwardEnable): + in_ckpt = False + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.meta['activation_checkpoint'] = [ckpt_idx] + + ckpt_idx += 1 + ckpt_region = [] + + elif isinstance(op, ForwardCheck): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.meta['activation_checkpoint'] = [ckpt_idx] + + ckpt_idx += 1 + ckpt_region = [idx] + + else: + if isinstance(op, ForwardCheck): + in_ckpt = True + ckpt_region.append(idx) + + # annotate the backward if there is any nested activation checkpoint + in_recompute = False + for op in bwd_list: + if in_recompute: + if isinstance(op, ForwardNograd): + ckpt_region.append(op.index) + + elif isinstance(op, ForwardEnable): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.meta['activation_checkpoint'].append(ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [] + + elif isinstance(op, ForwardCheck): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.meta['activation_checkpoint'].append(ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [op.index] + + elif isinstance(op, Backward): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.meta['activation_checkpoint'].append(ckpt_idx) + + in_recompute = False + + else: + if not isinstance(op, Backward): + in_recompute = True + ckpt_idx = 0 + ckpt_region = [] + if isinstance(op, ForwardCheck): + ckpt_region.append(op.index) + + # postprocess, make sure every activation checkpoint label in the + # same activation checkpoint region (level = 0) has the same length + op_list = [] + for node in node_list: + op_list += node + ckpt_regions = _find_nested_ckpt_regions(op_list) + for (start_idx, end_idx) in ckpt_regions: + nested_length = max( + len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1)) + for idx in range(start_idx, end_idx + 1): + op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length - + len(op_list[idx].meta['activation_checkpoint'])) diff --git a/colossalai/auto_parallel/checkpoint/operation.py b/colossalai/auto_parallel/checkpoint/operation.py new file mode 100644 index 0000000000000000000000000000000000000000..ab0c6c5ad38d171d470931aa9b2bdedf6cd17668 --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/operation.py @@ -0,0 +1,184 @@ +import math +from abc import ABC +from typing import Any, Iterable, List + +from torch.utils._pytree import tree_map + + +class Chain: + + def __init__(self, + ftime: List[float], + btime: List[float], + x: List[int], + xbar: List[int], + ftmp: List[int], + btmp: List[int], + check_consistency: bool = True): + """The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint. + See paper https://hal.inria.fr/hal-02352969 for details. + + Args: + ftime (List[float]): The forward time of each node. + btime (List[float]): The backward time of each node. + x (List[int]): The forward memory of each node (if save_output). Same as `a` in the paper. + xbar (List[int]): The forward memory of each node (if save_all). Same as `a_bar` in the paper. + ftmp (List[int]): The temporary forward memory of each node. + btmp (List[int]): The temporary backward memory of each node, can be used to control memory budget. + check_consistency (bool, optional): Check the lengths consistency for the `Chain`. Defaults to True. + """ + self.ftime = ftime + self.btime = btime + self.x = x + self.xbar = xbar + self.ftmp = ftmp + self.btmp = btmp + if check_consistency and not self.check_lengths(): + raise AttributeError("In Chain, input lists do not have consistent lengths") + + def check_lengths(self): + return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1) + and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1) + and (len(self.xbar) == len(self) + 1)) + + def __repr__(self): + chain_list = [] + for i in range(len(self)): + chain_list.append((self.ftime[i], self.btime[i], self.x[i], self.xbar[i], self.ftmp[i], self.btmp[i])) + i = len(self) + chain_list.append((None, self.btime[i], self.x[i], self.xbar[i], None, self.btmp[i])) + return chain_list.__repr__() + + def __len__(self): + return len(self.ftime) + + def discretize_all(self, unit: int): + """Discretize the chain into a list of chains according to unit size.""" + discretizer = lambda val: math.ceil(val / unit) + self.x = tree_map(discretizer, self.x) + self.xbar = tree_map(discretizer, self.xbar) + self.ftmp = tree_map(discretizer, self.ftmp) + self.btmp = tree_map(discretizer, self.btmp) + + +class Operation(ABC): + name = "Op" + + def __repr__(self) -> str: + return f"{self.name}_{self.index}" + + def shift(self, value): + if type(self.index) is tuple: + self.index = tuple(x + value for x in self.index) + else: + self.index += value + + +class Forward(Operation): + name = "F" + + def __init__(self, index): + self.index = index + + def cost(self, chain: Chain): + if chain is not None: + return chain.ftime[self.index] + else: + return 1 + + +class ForwardEnable(Forward): + name = "Fe" + + +class ForwardNograd(Forward): + name = "Fn" + + +class ForwardCheck(Forward): + name = "CF" + + +class Forwards(Operation): + + def __init__(self, start, end): + self.index = (start, end) + + def __repr__(self): + return "F_{i}->{j}".format(i=self.index[0], j=self.index[1]) + + def cost(self, chain: Chain): + if chain is not None: + return sum(chain.ftime[self.index[0]:self.index[1] + 1]) + else: + return (self.index[1] - self.index[0] + 1) + + +def isForward(op): + return type(op) is Forward or type(op) is Forwards + + +class Backward(Operation): + name = "B" + + def __init__(self, index): + self.index = index + + def cost(self, chain: Chain): + if chain is not None: + return chain.btime[self.index] + else: + return 1 + + +class Loss(Operation): + + def __init__(self): + pass + + def __repr__(self): + return "L" + + def cost(self, chain): + return 0 + + +class MemoryAccess(Operation): + name = "MA" + + def __init__(self, index): + self.index = index + + def cost(self, chain: Chain): + return 0 + + +class WriteMemory(MemoryAccess): + name = "WM" + + +class ReadMemory(MemoryAccess): + name = "RM" + + +class DiscardMemory(MemoryAccess): + name = "DM" + + +class Sequence(list): + + def __init__(self): + super().__init__() + + def __repr__(self): + return repr(self.list_operations()) + + def list_operations(self): + op_list = [] + for x in self: + if isinstance(x, Operation): + op_list.append(x) + else: + assert isinstance(x, Sequence) + op_list += x.list_operations() + return op_list diff --git a/colossalai/auto_parallel/meta_profiler/__init__.py b/colossalai/auto_parallel/meta_profiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfd36195149b3077a8710e73c927501c96a4b9eb --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/__init__.py @@ -0,0 +1,3 @@ +from .meta_registry import * +from .metainfo import * +from .registry import meta_register diff --git a/colossalai/auto_parallel/meta_profiler/constants.py b/colossalai/auto_parallel/meta_profiler/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..714674b7b42534974af805c73f3b98ec6e5c1482 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/constants.py @@ -0,0 +1,12 @@ +import operator + +import torch +import torch.nn as nn + +from ..tensor_shard.constants import * + +# list of inplace operations +INPLACE_MODULE = [nn.ReLU] + +# list of operations that do not save forward activations +NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub] diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5f77f6591e077341b4c4363f4ac850e158bec5 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py @@ -0,0 +1,6 @@ +from .activation import * +from .binary_elementwise_ops import * +from .conv import * +from .linear import * +from .norm import * +from .pooling import * diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..dc62005f0906418c201588e7a7b0dafbef3424b7 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py @@ -0,0 +1,70 @@ +from typing import List, Tuple + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem +from colossalai.fx.profiler.memory_utils import activation_size +from colossalai.fx.profiler.opcount import flop_mapping + +from ..registry import meta_register + +__all__ = ["relu_meta_info"] + + +@meta_register.register(torch.nn.ReLU) +def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """torch.nn.ReLU metainfo generator + The aten graph of torch.nn.ReLU is + graph(): + %input_2 : [#users=1] = placeholder[target=placeholder](default=) + %relu_default : [#users=2] = call_function[target=torch.ops.aten.relu.default](args = (%input_2,), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%relu_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%relu_default,), kwargs = {}) + %threshold_backward_default : [#users=1] = call_function[target=torch.ops.aten.threshold_backward.default](args = (%zeros_like_default, %detach_default, None), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%threshold_backward_default,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + inplace = kwargs.get("inplace", False) + + # construct input args for forward + fwd_in_args = [input_tensor] + + # construct input args for backward + bwd_in_args = [output_tensor] + + # calculate cost + # the fwd op with compute cost is relu.default + # the bwd op with compute cost is threshold_backward + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.relu.default](fwd_in_args, (output_tensor,)) + bwd_compute_cost = flop_mapping[torch.ops.aten.threshold_backward.default](bwd_in_args, (input_tensor,)) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # NOTE: the inplace ReLU don't have forward memory cost + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_memory_cost = MemoryCost( + activation=activation_size(input_tensor) if inplace else activation_size([output_tensor, input_tensor]), + parameter=0, + temp=0, + buffer=0) + + bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), parameter=0, temp=0, buffer=0) + + # total cost is the sum of forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + # store fwd_in + fwd_in = [input_tensor] + + return compute_cost, memory_cost, fwd_in diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0292121b60bf03c664521550d8054ff9dbb759a0 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py @@ -0,0 +1,65 @@ +from typing import List, Tuple + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem +from colossalai.fx.profiler.memory_utils import activation_size +from colossalai.fx.profiler.opcount import flop_mapping + +from ..constants import BCAST_FUNC_OP +from ..registry import meta_register + +__all__ = ['binary_elementwise_meta_info'] + + +@meta_register.register(BCAST_FUNC_OP) +def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """Meta information generator for binary elementwise operations + NOTE: Some of the binary elementwise operations will discard the input activation after computation, as they + don't need those tensors for back propagation, for example, if there are two tensors being sent for `torch.add`, + they will be discarded right after add operation is done. We create a simple API in `MetaInfo` class to identify + this behavior, it is critical for better memory estimation. + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT] + output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args)) + + # construct forward args for flop mapping + fwd_in_args = [input_op_data.data, other_op_data.data] + fwd_out_args = [output_op_data.data] + + # calculate cost + + # calculate compute cost + # NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case + fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args) + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + param_mem_cost = activation_size( + [arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM]) + fwd_mem_cost = MemoryCost( + activation=activation_size([input_op_data.data, output_op_data.data]), + parameter=param_mem_cost, + ) + bwd_mem_cost = MemoryCost( + activation=activation_size(fwd_in_args), + parameter=param_mem_cost, + ) + + # total cost + total_mem_cost = MemoryCost( + activation=fwd_mem_cost.activation + bwd_mem_cost.activation, + parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, + ) + + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + + # store fwd_in + fwd_in = fwd_in_args + + return compute_cost, memory_cost, fwd_in diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..f7d55529fb9c73221b1e916c3ecd32cd37a72d14 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py @@ -0,0 +1,132 @@ +from typing import Callable, Dict, List, Tuple, Union + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.fx.profiler.memory_utils import activation_size +from colossalai.fx.profiler.opcount import flop_mapping +from colossalai.tensor.sharding_spec import ShardingSpec + +from ..registry import meta_register + +__all__ = ['convnd_meta_info'] + + +@meta_register.register(torch.nn.Conv1d) +@meta_register.register(torch.nn.Conv2d) +@meta_register.register(torch.nn.Conv3d) +@meta_register.register(torch.nn.functional.conv1d) +@meta_register.register(torch.nn.functional.conv2d) +@meta_register.register(torch.nn.functional.conv3d) +def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d meta info generator + The atens graph of torch.nn.Convnd with bias is + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %convolution_default : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%input_2, None, None, [None, None, None], [None, None, None], [None, None, None], None, [None, None, None], None), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%convolution_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %convolution_backward_default : [#users=3] = call_function[target=torch.ops.aten.convolution_backward.default](args = (%zeros_like_default, %detach_default, None, [None], [None, None, None], [None, None, None], [None, None, None], None, [None, None, None], None, [None, None, None]), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {}) + %detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {}) + %detach_default_5 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {}) + %detach_default_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_5,), kwargs = {}) + + The atens graph of torch.nn.Convnd without bias is + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %convolution_default : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%input_2, None, None, [None, None], [None, None], [None, None], None, [None, None], None), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%convolution_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %convolution_backward_default : [#users=2] = call_function[target=torch.ops.aten.convolution_backward.default](args = (%zeros_like_default, %detach_default, None, [None], [None, None], [None, None], [None, None], None, [None, None], None, [None, None, None]), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {}) + %detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {}) + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + has_bias: bool = False + input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM] + + # check if conv has bias + if len(weight_tensors) > 1: + has_bias = True + # bias tensor's shape only has one dimension + if len(weight_tensors[0].shape) == 1: + bias_tensor, weight_tensor = weight_tensors + else: + weight_tensor, bias_tensor = weight_tensors + + else: + weight_tensor = weight_tensors[0] + + # construct input args for forward + fwd_args = [None] * 9 + + # weight and input + fwd_args[0] = input_tensor + fwd_args[1] = weight_tensor + fwd_args[2] = bias_tensor if has_bias else None + + # transpose indicator should be set to False + fwd_args[6] = False + + # construct input args for backward + bwd_args = [None] * 11 + + # weight and input + bwd_args[0] = output_tensor + bwd_args[1] = input_tensor + bwd_args[2] = weight_tensor + bwd_args[-1] = [True, True, True] if has_bias else [True, True, False] + + # calculate cost + # the fwd op with compute cost is convolution.default + # the bwd op with compute cost is convolution_backward.default + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,)) + bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \ + flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor)) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # TODO: use profiler to check conv temp memory + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_memory_cost = MemoryCost( + activation=activation_size([input_tensor, output_tensor]), + parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor), + temp=0, + buffer=0) + + bwd_memory_cost = MemoryCost( + activation=activation_size([input_tensor, weight_tensor, bias_tensor]) + if has_bias else activation_size([input_tensor, weight_tensor]), + parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor), + temp=0, + buffer=0) + + # total cost is the sum of forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + # store fwd_in + fwd_in = [input_tensor] + + return compute_cost, memory_cost, fwd_in diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..b48748fa9826ef9f3374d2f87359667b0e871722 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py @@ -0,0 +1,166 @@ +from typing import Callable, Dict, List, Tuple, Union + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.fx.profiler.memory_utils import activation_size +from colossalai.fx.profiler.opcount import flop_mapping +from colossalai.tensor.sharding_spec import ShardingSpec + +from ..registry import meta_register + +__all__ = ['linear_meta_info'] + + +@meta_register.register(torch.nn.functional.linear) +@meta_register.register(torch.nn.Linear) +def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """torch.nn.Linear & torch.nn.functional.linear meta info generator + NOTE: currently we separate the bias part from the biased linear ops, we will consider the memory consumption in add metainfo generator, + but we will hold the bias mechanism in the linear metainfo generator for future use. + + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (None, %input_2, None), kwargs = {}) + %zeros_like_default : [#users=3] = call_function[target=torch.ops.aten.zeros_like.default](args = (%addmm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {}) + %t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {}) + %mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {}) + %t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {}) + %sum_dim_int_list : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%zeros_like_default, [None], None), kwargs = {}) + %view_default : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_dim_int_list, [None]), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%view_default,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default,), kwargs = {}) + %detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {}) + %t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {}) + %detach_default_5 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {}) + %detach_default_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_5,), kwargs = {}) + + The one without bias is + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%input_2, None), kwargs = {}) + %zeros_like_default : [#users=2] = call_function[target=torch.ops.aten.zeros_like.default](args = (%mm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {}) + %mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {}) + %t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {}) + %mm_default_2 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default_2,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + %t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {}) + %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {}) + %detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {}) + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, bool]: compute cost, memory cost and forward inputs + """ + + has_bias: bool = False + input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM] + + # process the dimension of input and output + if len(input_tensor.shape) > 2: + input_tensor: torch.Tensor + input_tensor = input_tensor.view(-1, input_tensor.shape[-1]) + + if len(output_tensor.shape) > 2: + output_tensor: torch.Tensor + output_tensor = output_tensor.view(-1, output_tensor.shape[-1]) + + if len(weight_tensors) > 1: + has_bias = True + if len(weight_tensors[0].shape) == 2: + weight_tensor, bias_tensor = weight_tensors + else: + bias_tensor, weight_tensor = weight_tensors + else: + weight_tensor = weight_tensors[0] + + if has_bias: + # calculate cost with bias + # the fwd op with compute cost is addmm + # the bwd op with compute cost is mm * 2 and sum.dim_IntList + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default]( + [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)) + bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \ + flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \ + flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,)) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # NOTE: Linear don't have buffer and temp in forward and backward phase + # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]), + parameter=activation_size([weight_tensor, bias_tensor]), + temp=0, + buffer=0) + + # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0 + bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]), + parameter=activation_size([weight_tensor, bias_tensor]), + temp=0, + buffer=0) + + # total cost is to sum the forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + else: + # calculate cost without bias + # the fwd op with compute cost is mm + # the bwd op with compute cost is mm * 2 + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( + [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)) + bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \ + flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # NOTE: Linear don't have buffer and temp in forward and backward phase + # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]), + parameter=activation_size(weight_tensor), + temp=0, + buffer=0) + + # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0 + bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor]), + parameter=activation_size(weight_tensor), + temp=0, + buffer=0) + + # total cost is to sum the forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + # store fwd_in + fwd_in = [input_tensor] + + return compute_cost, memory_cost, fwd_in diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..395eecdbb8c5da550f5cbe0d46279aad12eea9d5 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py @@ -0,0 +1,101 @@ +from typing import Callable, Dict, List, Tuple, Union + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.fx.profiler.memory_utils import activation_size +from colossalai.fx.profiler.opcount import flop_mapping +from colossalai.tensor.sharding_spec import ShardingSpec + +from ..registry import meta_register + +__all__ = ['batchnormnd_meta_info'] + + +@meta_register.register(torch.nn.BatchNorm1d) +@meta_register.register(torch.nn.BatchNorm2d) +@meta_register.register(torch.nn.BatchNorm3d) +def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """BatchNorm1d, BatchNorm2d, BatchNorm3d, meta info generator + The aten graph of BatchNorm2d is like + + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %cudnn_batch_norm_default : [#users=4] = call_function[target=torch.ops.aten.cudnn_batch_norm.default](args = (%input_2, None, None, None, None, None, None, None), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%cudnn_batch_norm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {}) + %detach_default_2 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {}) + %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {}) + %cudnn_batch_norm_backward_default : [#users=3] = call_function[target=torch.ops.aten.cudnn_batch_norm_backward.default](args = (%detach_default, %zeros_like_default, None, None, None, %detach_default_1, %detach_default_2, None, %detach_default_3), kwargs = {}) + %detach_default_4 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {}) + %detach_default_5 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_4,), kwargs = {}) + %detach_default_6 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {}) + %detach_default_7 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_6,), kwargs = {}) + %detach_default_8 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {}) + %detach_default_9 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_8,), kwargs = {}) + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + weight_tensor = next(filter(lambda x: x.name == "weight", args)).data + bias_tensor = next(filter(lambda x: x.name == "bias", args)).data + mean_tensor = next(filter(lambda x: x.name == "running_mean", args)).data + var_tensor = next(filter(lambda x: x.name == "running_var", args)).data + num_batch = next(filter(lambda x: x.name == "num_batches_tracked", args)).data + + # construct fwd args + # the fwd inputs are input, weight, bias, running_mean, running_var and some other args + # indicating the status of the module + # the fwd outputs are output, saved mean, saved inv std and num batches tracked + fwd_in_args = [input_tensor, weight_tensor, bias_tensor, mean_tensor, var_tensor, True, 0.1, 1e-5] + fwd_out_args = [output_tensor, mean_tensor, var_tensor, num_batch] + + # construct bwd args + # the bwd inputs are upstream grad, input, weight, running_mean, running_var, saved mean, + # saved inv std and some other args indicating the status of the module + # the bwd outputs are input grad, weight grad and bias grad + bwd_in_args = [ + output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch + ] + bwd_out_args = [input_tensor, weight_tensor, bias_tensor] + + # calculate cost + fwd_compute_cost = flop_mapping[torch.ops.aten.cudnn_batch_norm.default](fwd_in_args, fwd_out_args) + bwd_compute_cost = flop_mapping[torch.ops.aten.cudnn_batch_norm_backward.default](bwd_in_args, bwd_out_args) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # the fwd activation cost is output plus saved mean and saved inv std + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, mean_tensor, var_tensor]), + parameter=activation_size([weight_tensor, bias_tensor]), + temp=0, + buffer=activation_size([mean_tensor, var_tensor])) + + # the bwd memory cost is quite tricky here, BatchNorm will remove saved mean + # and saved inv std during backward phase + bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor]), + parameter=activation_size([weight_tensor, bias_tensor]), + temp=activation_size([mean_tensor, var_tensor]), + buffer=activation_size([mean_tensor, var_tensor])) + + # total cost is the sum of forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + # store fwd_in + fwd_in = [input_tensor] + + return compute_cost, memory_cost, fwd_in diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..63f321519772d7dc702015c57e030140d8c0481e --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py @@ -0,0 +1,128 @@ +from typing import List, Tuple + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem +from colossalai.fx.profiler.memory_utils import activation_size +from colossalai.fx.profiler.opcount import flop_mapping + +from ..registry import meta_register + +__all__ = ["avgpool_meta_info", "maxpool_meta_info"] + + +@meta_register.register(torch.nn.AdaptiveAvgPool1d) +@meta_register.register(torch.nn.AdaptiveAvgPool2d) +@meta_register.register(torch.nn.AdaptiveAvgPool3d) +def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """Meta info for AdaptiveAvgPool + The aten graph of AdaptiveAvgPool is + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %_adaptive_avg_pool2d_default : [#users=1] = call_function[target=torch.ops.aten._adaptive_avg_pool2d.default](args = (%input_2, [None, None]), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%_adaptive_avg_pool2d_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %_adaptive_avg_pool2d_backward_default : [#users=1] = call_function[target=torch.ops.aten._adaptive_avg_pool2d_backward.default](args = (%zeros_like_default, %detach_default), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%_adaptive_avg_pool2d_backward_default,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + + # construct forward args for flop mapping + fwd_in_args = [input_tensor] + fwd_out_args = [output_tensor] + + # construct backward args for flop mapping + bwd_in_args = [output_tensor] + bwd_out_args = [input_tensor] + + # calculate cost + # the fwd op with compute cost is _adaptive_avg_pool2d.default + # the bwd op with compute cost is _adaptive_avg_pool2d_backward.default + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args) + bwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d_backward.default](bwd_in_args, bwd_out_args) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + fwd_mem_cost = MemoryCost(activation=activation_size(output_tensor)) + bwd_mem_cost = MemoryCost(activation=activation_size(input_tensor)) + + # total cost + total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation) + + mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + + # store_fwd_in + fwd_in = [input_tensor] + + return compute_cost, mem_cost, fwd_in + + +@meta_register.register(torch.nn.MaxPool1d) +@meta_register.register(torch.nn.MaxPool2d) +@meta_register.register(torch.nn.MaxPool3d) +def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """Meta info for MaxPool + The aten graph of MaxPool is + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %max_pool2d_with_indices_default : [#users=2] = call_function[target=torch.ops.aten.max_pool2d_with_indices.default](args = (%input_2, [None, None], [None, None]), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%max_pool2d_with_indices_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%max_pool2d_with_indices_default,), kwargs = {}) + %max_pool2d_with_indices_backward_default : [#users=1] = call_function[target=torch.ops.aten.max_pool2d_with_indices_backward.default](args = (%zeros_like_default, %detach_default, [None, None], [None, None], [None, None], [None, None], None, %detach_default_1), kwargs = {}) + %detach_default_2 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%max_pool2d_with_indices_backward_default,), kwargs = {}) + %detach_default_3 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_2,), kwargs = {}) + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + + # construct forward args for flop mapping + fwd_in_args = [input_tensor] + fwd_out_args = [output_tensor] + + # construct backward args for flop mapping + bwd_in_args = [output_tensor] + bwd_out_args = [input_tensor] + + # construct index matrix + index_matrix = torch.zeros_like(output_tensor, device="meta", dtype=torch.int64) + + # calculate cost + # the fwd op with compute cost is max_pool2d_with_indices.default + # the bwd op with compute cost is max_pool2d_with_indices_backward.default + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.max_pool2d_with_indices.default](fwd_in_args, fwd_out_args) + bwd_compute_cost = flop_mapping[torch.ops.aten.max_pool2d_with_indices_backward.default](bwd_in_args, bwd_out_args) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # NOTE: the index matrix will be discarded in backward phase + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_mem_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, index_matrix])) + + # temp memory for backward is the index matrix to be discarded + bwd_mem_cost = MemoryCost(activation=activation_size(input_tensor) - activation_size(index_matrix), + temp=activation_size(index_matrix)) + + # total cost + total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp) + + mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + + # store_fwd_in + fwd_in = [input_tensor] + + return compute_cost, mem_cost, fwd_in diff --git a/colossalai/auto_parallel/meta_profiler/metainfo.py b/colossalai/auto_parallel/meta_profiler/metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..b7cbc57bd6ff462a20572b5207058cec8fad4c2e --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/metainfo.py @@ -0,0 +1,121 @@ +from typing import Callable + +import numpy as np +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.tensor.sharding_spec import ShardingSpec + +from .constants import INPLACE_MODULE, NO_SAVE_ACTIVATION +from .registry import meta_register + +__all__ = ['MetaInfo'] + + +class MetaInfo: + """MetaInfo class + This class is used to store meta info based on sharding strategy and the given + target function. + """ + + def __init__(self, strategy: ShardingStrategy = None, target: Callable = None) -> None: + # compute cost of forward and backward computation + self.compute_cost: TrainCycleItem + + # compute memory cost of forward and backward phase + self.memory_cost: TrainCycleItem + + # list of input tensors + self.fwd_in: list[OperationData] + + # bool type to indicate whether the function will save forward activation + self.save_fwd_in: bool + + # sharding strategy + self._strategy = strategy + + # target function + self._target = target + + # compute metainfo if possible + if self._strategy is not None and self._target is not None: + self.compute_metainfo() + + @property + def strategy(self) -> ShardingStrategy: + return self._strategy + + @property + def target(self) -> Callable: + return self._target + + @strategy.setter + def strategy(self, strategy: ShardingStrategy) -> None: + self._strategy = strategy + if self._strategy is not None and self._target is not None: + self.compute_metainfo() + + @target.setter + def target(self, target: Callable) -> None: + self._target = target + if self._strategy is not None and self._target is not None: + self.compute_metainfo() + + def compute_sharded_tensor(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor: + """ + Compute sharded meta tensor based on the given data and sharding spec. + """ + shard_sequnce = sharding_spec.sharding_sequence + device_mesh = sharding_spec.device_mesh + shape = operation_data.data.shape + + new_shape = [] + for dim, shard in zip(shape, shard_sequnce): + if shard.is_replica: + # replica + new_shape.append(dim) + else: + # sharded according to device_mesh shape + new_shape.append(dim // np.prod(np.array([device_mesh.mesh_shape[i] for i in shard.shard_list]))) + + return OperationData(name=operation_data.name, + data=torch.zeros(new_shape, device="meta"), + type=operation_data.type, + logical_shape=operation_data.logical_shape) + + def compute_metainfo(self): + """ + Compute meta info based on sharding strategy and the given target function. + """ + + try: + # module + meta_func = meta_register.get(self._target.__class__) + + # check whether the target in the module list that we don't need to save activation + self.save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION + except: + # function + meta_func = meta_register.get(self._target) + + # check whether the target in the module list that we don't need to save activation + self.save_fwd_in = self._target not in NO_SAVE_ACTIVATION + + # construct args for meta_func + args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()] + + # construct kwargs + if self.target in INPLACE_MODULE: + kwargs = {'inplace': self.target.inplace} + else: + kwargs = {'inplace': False} + + # compute metainfo with meta_func + self.compute_cost, self.memory_cost, self.fwd_in = meta_func(*args, **kwargs) diff --git a/colossalai/auto_parallel/meta_profiler/registry.py b/colossalai/auto_parallel/meta_profiler/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..46350c4dd406691c344eb92a933636d6b029b8bd --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/registry.py @@ -0,0 +1,32 @@ +__all__ = ['Registry'] + + +class Registry: + + def __init__(self, name): + self.name = name + self.store = {} + + def register(self, source): + + def wrapper(func): + if isinstance(source, (list, tuple)): + # support register a list of items for this func + for element in source: + self.store[element] = func + else: + self.store[source] = func + return func + + return wrapper + + def get(self, source): + assert source in self.store, f'{source} not found in the {self.name} registry' + target = self.store[source] + return target + + def has(self, source): + return source in self.store + + +meta_register = Registry('meta') diff --git a/colossalai/auto_parallel/passes/__init__.py b/colossalai/auto_parallel/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..b81402c27fd19e5236f65dd35736c96bd1255aa1 --- /dev/null +++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py @@ -0,0 +1,219 @@ +from copy import deepcopy +from typing import Dict, List + +import torch +from torch.fx.node import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + OperationData, + OperationDataType, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.comm_spec import CommSpec +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +shape_consistency_manager = ShapeConsistencyManager() + + +def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int): + """ + This method will be invoked during runtime to do the shape consistency, which make sure the activations is converted into + the user node expected form. + """ + origin_sharding_spec = origin_dict[node_index] + target_sharding_spec = input_dict[node_index][user_node_index] + return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec) + + +def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, + user_node_index: int): + """ + This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list + is converted into the user node expected form. + """ + rst = [] + for index, (origin_sharding_spec, + target_sharding_spec) in enumerate(zip(origin_dict[node_index], + input_dict[node_index][user_node_index])): + rst.append( + shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec, + target_sharding_spec)) + rst = type(node)(rst) + return rst + + +def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str): + """ + This method will be invoked during runtime to apply the comm action following the instruction of comm spec. + """ + comm_action = comm_actions_dict[node_index][op_data_name] + if isinstance(comm_action.comm_spec, CommSpec): + rst = comm_action.comm_spec.covert_spec_to_action(tensor) + else: + origin_sharding_spec = comm_action.comm_spec['src_spec'] + tgt_sharding_spec = comm_action.comm_spec['tgt_spec'] + rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec) + return rst + + +def _preprocess_graph(nodes: List[Node]): + """ + This method is used to extract all the placeholders with sharding information, + and mapping the nodes into the index of the origin graph. + """ + # mapping the node into the origin graph index + node_to_index_dict = {} + index = 0 + for node in nodes: + if node.target == 'sharding_spec_convert_dict': + input_dict_node = node + continue + if node.target == 'origin_node_sharding_spec_dict': + origin_dict_node = node + continue + if node.target == 'comm_actions_dict': + comm_actions_dict_node = node + continue + if not hasattr(node, 'best_strategy'): + continue + node_to_index_dict[node] = index + index += 1 + + return input_dict_node, origin_dict_node, comm_actions_dict_node, node_to_index_dict + + +def _shape_consistency_apply(gm: torch.fx.GraphModule): + """ + This pass is used to add the shape consistency node to the origin graph. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + + input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes) + + for node in nodes: + if not hasattr(node, 'best_strategy') or node.op == 'output': + continue + + for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes): + if isinstance(node.sharding_spec, (list, tuple)): + assert isinstance( + node.target_sharding_specs, + (list, + tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list' + total_difference = 0 + for sharding_spec, target_sharding_spec in zip(node.sharding_spec, + node.target_sharding_specs[user_node_index]): + total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec) + if total_difference == 0: + continue + with mod_graph.inserting_before(user_node): + shape_consistency_node = mod_graph.create_node('call_function', + runtime_apply_for_iterable_object, + args=(node, origin_dict_node, input_dict_node, + node_to_index_dict[node], user_node_index)) + + else: + assert isinstance(node.sharding_spec, + ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.' + if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0: + continue + with mod_graph.inserting_before(user_node): + shape_consistency_node = mod_graph.create_node('call_function', + runtime_apply, + args=(node, origin_dict_node, input_dict_node, + node_to_index_dict[node], user_node_index)) + + new_args = list(user_node.args) + new_kwargs = dict(user_node.kwargs) + # the origin node may be a positional argument or key word argument of user node + if node in new_args: + # substitute the origin node with shape_consistency_node + origin_index_args = new_args.index(node) + new_args[origin_index_args] = shape_consistency_node + user_node.args = tuple(new_args) + elif str(node) in new_kwargs: + # substitute the origin node with shape_consistency_node + new_kwargs[str(node)] = shape_consistency_node + user_node.kwargs = new_kwargs + + return gm + + +def _comm_spec_apply(gm: torch.fx.GraphModule): + """ + This pass is used to add the comm spec apply node to the origin graph. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + + _, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes) + + for node in nodes: + if not hasattr(node, 'best_strategy') or node.op == 'output': + continue + + comm_actions = node.best_strategy.communication_actions + for op_data, comm_action in comm_actions.items(): + + if comm_action.comm_type == CommType.HOOK: + continue + if comm_action.comm_type == CommType.BEFORE: + if op_data.type == OperationDataType.OUTPUT: + comm_object = node + elif comm_action.key_for_kwarg is not None: + comm_object = node.kwargs[comm_action.key_for_kwarg] + else: + comm_object = node.args[comm_action.arg_index] + with mod_graph.inserting_before(node): + comm_spec_apply_node = mod_graph.create_node('call_function', + runtime_comm_spec_apply, + args=(comm_object, comm_actions_dict_node, + node_to_index_dict[node], op_data.name)) + # the origin node may be a positional argument or key word argument of user node + if comm_action.key_for_kwarg is not None: + # substitute the origin node with comm_spec_apply_node + new_kwargs = dict(node.kwargs) + new_kwargs[comm_action.key_for_kwarg] = comm_spec_apply_node + node.kwargs = new_kwargs + else: + # substitute the origin node with comm_spec_apply_node + new_args = list(node.args) + new_args[comm_action.arg_index] = comm_spec_apply_node + node.args = tuple(new_args) + + elif comm_action.comm_type == CommType.AFTER: + with mod_graph.inserting_after(node): + comm_spec_apply_node = mod_graph.create_node('call_function', + runtime_comm_spec_apply, + args=(node, comm_actions_dict_node, + node_to_index_dict[node], op_data.name)) + user_list = list(node.users.keys()) + for user in user_list: + if user == comm_spec_apply_node: + continue + new_args = list(user.args) + new_kwargs = dict(user.kwargs) + # the origin node may be a positional argument or key word argument of user node + if node in new_args: + # substitute the origin node with comm_spec_apply_node + new_args[new_args.index(node)] = comm_spec_apply_node + user.args = tuple(new_args) + elif str(node) in new_kwargs: + # substitute the origin node with comm_spec_apply_node + new_kwargs[str(node)] = comm_spec_apply_node + user.kwargs = new_kwargs + return gm + + +def runtime_apply_pass(gm: torch.fx.GraphModule): + """ + The method manages all the passes acting on the distributed training runtime. + """ + gm = _shape_consistency_apply(gm) + gm = _comm_spec_apply(gm) + + return gm diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..92916118bf9b9c61ed2388462bcae89aa321521c --- /dev/null +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -0,0 +1,459 @@ +import operator +from copy import deepcopy +from typing import Dict, List, Union + +import torch +from torch.fx import symbolic_trace +from torch.fx.node import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + OperationDataType, + ShardingStrategy, +) +from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.comm_spec import _all_reduce +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +shape_consistency_manager = ShapeConsistencyManager() + + +def size_processing(size: Union[int, torch.Size], + dim_partition_dict: Dict[int, List[int]], + device_mesh_info: Dict[int, int], + target_dim: int = None, + node_name: str = None): + """ + This method will be invoked during runtime to convert size node value depending on distributed information. + """ + if target_dim is not None: + assert isinstance(size, int) + if target_dim in dim_partition_dict: + total_shard_size = 1 + for shard_dim in dim_partition_dict[target_dim]: + total_shard_size *= device_mesh_info[shard_dim] + size = size * total_shard_size + + else: + size = list(size) + for dim, dim_size in enumerate(size): + if dim in dim_partition_dict: + total_shard_size = 1 + for shard_dim in dim_partition_dict[dim]: + total_shard_size *= device_mesh_info[shard_dim] + size[dim] = dim_size * total_shard_size + size = torch.Size(size) + + return size + + +def _solution_annotatation(gm: torch.fx.GraphModule, + solution: List[int], + strategies_constructor: StrategiesConstructor = None): + """ + This method is used to stick the solution strategy to the nodes and add the information + required in runtime into graph as placeholder nodes. + """ + mod_graph = gm.graph + # TODO: In future PR, strategies_constructor should be a required argument, + # instead of optional argument. This is because we don't need to consider nodes with + # no strategy in runtime preparation pass. + if strategies_constructor is not None: + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + no_strategy_nodes = strategies_constructor.no_strategy_nodes + else: + nodes = tuple(mod_graph.nodes) + no_strategy_nodes = [] + + # the dict to get origin sharding spec of node + origin_node_sharding_spec_dict = {} + for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)): + strategies_vector = node.strategies_vector + # stick the solution strategy to the corresponding node + setattr(node, 'best_strategy', strategies_vector[strategy_index]) + setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node))) + origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name( + str(node)) + + # the dict to get input sharding specs of user node + sharding_spec_convert_dict = {} + # the dict to record comm actions of nodes + comm_actions_dict = {} + for index, node in enumerate(nodes): + target_sharding_specs = [] + for user_node in node.strategies_vector.successor_nodes: + if user_node in no_strategy_nodes: + target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(str(node.name)) + else: + target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name)) + target_sharding_specs.append(target_sharding_spec) + sharding_spec_convert_dict[index] = target_sharding_specs + setattr(node, 'target_sharding_specs', target_sharding_specs) + # the get_attr node strategy is kind of pending strategy, which means we will change it + # to the same strategy of the user node. + if node.op == 'get_attr': + assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.' + new_sharding_spec = target_sharding_specs[0] + user_strategy = node.strategies_vector.successor_nodes[0].best_strategy + op_data_in_user = user_strategy.get_op_data_by_name(str(node)) + origin_node_sharding_spec_dict[index] = new_sharding_spec + origin_pending_strategy = node.best_strategy + origin_op_data = origin_pending_strategy.get_op_data_by_name(str(node)) + new_sharding_specs = origin_pending_strategy.sharding_specs + new_sharding_specs[origin_op_data] = new_sharding_spec + new_communication_actions = {} + if op_data_in_user in user_strategy.communication_actions: + new_communication_action = user_strategy.communication_actions.pop(op_data_in_user) + new_communication_action.arg_index = 0 + new_communication_actions[origin_op_data] = new_communication_action + new_strategy = ShardingStrategy(name=str(new_sharding_spec.sharding_sequence), + sharding_specs=new_sharding_specs, + compute_cost=origin_pending_strategy.compute_cost, + communication_cost=origin_pending_strategy.communication_cost, + memory_cost=origin_pending_strategy.memory_cost, + communication_actions=new_communication_actions) + setattr(node, 'best_strategy', new_strategy) + setattr(node, 'sharding_spec', new_sharding_spec) + comm_action_dict = {} + for op_data, comm_action in node.best_strategy.communication_actions.items(): + comm_action_dict[op_data.name] = comm_action + comm_actions_dict[index] = comm_action_dict + + # add above dicts into graph + for node in nodes: + if node.op != 'placeholder': + with mod_graph.inserting_before(node): + input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict') + origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict') + comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict') + break + return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict + + +def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): + """ + In the auto parallel system, tensors may get shard on different devices, so the size of tensors + need to be converted to the size of original tensor and managed by the users, such as torch.view, + torch.reshape, etc. These nodes have enough information like input sharding_spec and + output sharding_spec to decide how to convert the size value. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + node_pairs = {} + + for node in nodes: + + if node.op == 'call_method' and node.target == 'size': + # extract useful information from size node + # dim_partition_dict will instruct the size value on which + # dimension should be enlarged. + sharding_spec = node.args[0].sharding_spec + dim_partition_dict = sharding_spec.dim_partition_dict + + # there are two usages of torch.Tensor.size: + # tensor.size() + # tensor.size(dim) + # if a target_dim is assigned, then the output will be + # in type of int, instead of torch.Size + target_dim = None + if len(node.args) > 1: + target_dim = node.args[1] + if target_dim < 0: + target_dim += node.args[0]._meta_data.dim() + + # DeviceMesh information instructs the scaling of the size value + device_mesh_info = {} + for dim, dim_size in enumerate(device_mesh.mesh_shape): + device_mesh_info[dim] = dim_size + + with mod_graph.inserting_after(node): + size_processing_node = mod_graph.create_node('call_function', + size_processing, + args=(node, dim_partition_dict, device_mesh_info, + target_dim, node.name)) + # store original node and processing node pair in node_pairs dictioanry + # It will be used to replace the original node with processing node in slice object + node_pairs[node] = size_processing_node + size_processing_node._meta_data = node._meta_data + + user_list = list(node.users.keys()) + for user in user_list: + if user == size_processing_node: + continue + new_args = list(user.args) + new_kwargs = dict(user.kwargs) + # the origin node may be a positional argument or key word argument of user node + if node in new_args: + # substitute the origin node with size_processing_node + new_args[new_args.index(node)] = size_processing_node + user.args = tuple(new_args) + elif str(node) in new_kwargs: + # substitute the origin node with size_processing_node + new_kwargs[str(node)] = size_processing_node + user.kwargs = new_kwargs + + if node.op == 'call_function' and node.target == operator.getitem: + + getitem_index = node.args[1] + # slice object is quite special in torch.fx graph, + # On one side, we treat slice object same as type of int, + # so we do not create a node for slice object. On the other side, + # slice object could take fx.Node as its argument. And the user + # relationship cannot be tracked in fx graph. + # Therefore, I record the node_pairs in this pass, and use the it + # to replace the original node argument inside the slice object if + # it has been processed in above pass. + + # There are three main usages of operator.getitem: + # getitem(input, int) + # getitem(input, slice) + # getitem(input, Tuple[slice]) + # In this pass, we need process the last two cases because + # node arguments may potentially appear in these cases. + if isinstance(getitem_index, slice): + new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step + if getitem_index.start in node_pairs: + new_start = node_pairs[getitem_index.start] + elif getitem_index.stop in node_pairs: + new_stop = node_pairs[getitem_index.stop] + elif getitem_index.step in node_pairs: + new_step = node_pairs[getitem_index.step] + new_slice_item = slice(new_start, new_stop, new_step) + new_args = (node.args[0], new_slice_item) + node.args = new_args + + elif isinstance(getitem_index, (tuple, list)): + assert isinstance(getitem_index[0], slice) + new_slice_items = [] + + for slice_item in getitem_index: + new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step + if slice_item.start in node_pairs: + new_start = node_pairs[slice_item.start] + elif slice_item.stop in node_pairs: + new_stop = node_pairs[slice_item.stop] + elif slice_item.step in node_pairs: + new_step = node_pairs[slice_item.step] + new_slice_item = slice(new_start, new_stop, new_step) + new_slice_items.append(new_slice_item) + + new_args = (node.args[0], tuple(new_slice_items)) + node.args = new_args + + return gm + + +def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): + """ + This pass will process node args to adapt the distributed tensor layout. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + + for node in nodes: + # skip the placeholder node added in _solution_annotation pass + if not hasattr(node, 'sharding_spec'): + continue + + def _process_sharding_spec(sharding_spec): + if isinstance(sharding_spec, ShardingSpec): + dim_partition_dict = sharding_spec.dim_partition_dict + device_mesh = sharding_spec.device_mesh + return dim_partition_dict, device_mesh + if sharding_spec is None: + return None, None + assert isinstance(sharding_spec, + (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None' + + device_mesh = sharding_spec[0].device_mesh + dim_partition_dict = [] + for element in sharding_spec: + dim_partition_dict.append(_process_sharding_spec(element)) + return dim_partition_dict, sharding_spec + + output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec) + new_args = [] + + if node.op == 'call_method': + method = getattr(node.args[0]._meta_data.__class__, node.target) + # process the node with (input, *shape) style args + if method in (torch.Tensor.view, torch.Tensor.reshape): + + for arg in node.args: + if isinstance(arg, Node): + if isinstance(arg._meta_data, (int, tuple, list)): + new_args.append(arg._meta_data) + else: + new_args.append(arg) + else: + assert isinstance( + arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.' + new_args.append(arg) + + for dim, shard_dims in output_dim_partition_dict.items(): + total_shard_size = 1 + for shard_dim in shard_dims: + total_shard_size *= device_mesh.shape[shard_dim] + # There are two ways to use torch.view: + # 1. torch.view(input, *shape) + # 2. torch.view(input, shape) + if isinstance(new_args[1], int): + # we will skip the dim with -1 value + if new_args[dim + 1] == -1: + continue + else: + new_args[dim + 1] //= total_shard_size + else: + new_args[1] = list(new_args[1]) + # we will skip the dim with -1 value + if new_args[1][dim] == -1: + continue + else: + new_args[1][dim] //= total_shard_size + node.args = tuple(new_args) + + elif node.op == 'call_function': + target = node.target + # process the node with (input, torch.Size) style args + if target in (torch.reshape,): + for arg in node.args: + if isinstance(arg, Node): + if isinstance(arg._meta_data, (tuple, list)): + new_args.append(list(arg._meta_data)) + else: + new_args.append(arg) + else: + assert isinstance( + arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.' + new_args.append(list(arg)) + + for dim, shard_dims in output_dim_partition_dict.items(): + # we will skip the dim with -1 value + if new_args[1][dim] == -1: + continue + total_shard_size = 1 + for shard_dim in shard_dims: + total_shard_size *= device_mesh.shape[shard_dim] + new_args[1][dim] //= total_shard_size + node.args = tuple(new_args) + + return gm + + +def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): + """ + Apply the sharding action to the module parameters and buffers following the + instructions of solver solution. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + # This stream is created for overlaping the communication and computation. + reduction_stream = torch.cuda.Stream() + for node in nodes: + if node.op == 'call_module': + target_module = node.graph.owning_module.get_submodule(node.target) + + for name, param in target_module.named_parameters(): + target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) + # apply the sharding spec of parameters + if target_sharding_spec.dim_partition_dict != {}: + origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {}) + setattr(param, 'sharding_spec', origin_sharding_spec) + # TODO: build a ColoParamter class to manager the distributed parameters + # we could use .data here, because all the operations just happen before the real training + # loop, so we don't need to track these operations in the autograd graph. + param.data = shape_consistency_manager.apply_for_autoparallel_runtime( + param.data, param.sharding_spec, target_sharding_spec).detach().clone() + + setattr(target_module, name, param) + comm_actions = node.best_strategy.communication_actions + for operation_data, comm_action in comm_actions.items(): + comm_spec_to_use = comm_action.comm_spec + # register hook to the parameters + if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK: + + def wrapper(param, comm_spec): + + def hook_fn(grad): + _all_reduce(grad, comm_spec, async_op=False) + + param.register_hook(hook_fn) + + wrapper(param, comm_spec_to_use) + + sharded_buffer_dict = {} + # apply the sharding spec of buffers + for name, buffer in target_module.named_buffers(): + origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {}) + setattr(buffer, 'sharding_spec', origin_sharding_spec) + target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) + buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec) + sharded_buffer_dict[name] = buffer_sharded + + for name, buffer_sharded in sharded_buffer_dict.items(): + setattr(target_module, name, buffer_sharded.detach().clone()) + + if node.op == 'get_attr': + root = node.graph.owning_module + atoms = node.target.split(".") + attr_len = len(atoms) + if attr_len == 1: + target_module = root + target = getattr(root, atoms[0]) + else: + target_module = root.get_submodule(atoms[-2]) + target = getattr(target_module, atoms[-1]) + + target_sharding_spec = node.sharding_spec + if target_sharding_spec.dim_partition_dict != {}: + origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {}) + setattr(target, 'sharding_spec', origin_sharding_spec) + # TODO: build a ColoParamter class to manager the distributed parameters + # we could use .data here, because all the operations just happen before the real training + # loop, so we don't need to track these operations in the autograd graph. + target.data = shape_consistency_manager.apply_for_autoparallel_runtime( + target.data, target.sharding_spec, target_sharding_spec).detach().clone() + + assert hasattr(target_module, atoms[-1]) + setattr(target_module, atoms[-1], target) + + comm_actions = node.best_strategy.communication_actions + for operation_data, comm_action in comm_actions.items(): + comm_spec_to_use = comm_action.comm_spec + # register hook to the parameters + if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK: + + def wrapper(param, comm_spec): + + def hook_fn(grad): + _all_reduce(grad, comm_spec, async_op=False) + + param.register_hook(hook_fn) + + wrapper(target, comm_spec_to_use) + return gm + + +def implicit_comm_action_apply(gm: torch.fx.GraphModule): + """ + replace the origin kernel into kernel with implicit communication inside. + """ + pass + + +def runtime_preparation_pass(gm: torch.fx.GraphModule, + solution: List[int], + device_mesh: DeviceMesh, + strategies_constructor: StrategiesConstructor = None): + gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation( + gm, solution, strategies_constructor) + gm = _size_value_converting(gm, device_mesh) + gm = _node_args_converting(gm, device_mesh) + # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed. + # gm = implicit_comm_action_apply(gm) + gm = _module_params_sharding(gm, device_mesh) + + return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict diff --git a/colossalai/auto_parallel/pipeline_shard/__init__.py b/colossalai/auto_parallel/pipeline_shard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colossalai/auto_parallel/tensor_shard/__init__.py b/colossalai/auto_parallel/tensor_shard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colossalai/auto_parallel/tensor_shard/constants.py b/colossalai/auto_parallel/tensor_shard/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..99c1249340602daee1a1314f102bc600eae6667d --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/constants.py @@ -0,0 +1,91 @@ +import operator + +import torch + +__all__ = [ + 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', + 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP', + 'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST' +] + +ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] +ELEMENTWISE_FUNC_OP = [ + torch.abs, + torch.cos, + torch.exp, + operator.neg, + torch.multiply, + torch.nn.functional.relu, + torch.nn.functional.dropout, + # softmax should not be here + torch.nn.functional.softmax +] +ELEMENTWISE_METHOD_OP = [ + torch.Tensor.to, + torch.Tensor.type, + # TODO: contiguous maybe need some extra processes. + torch.Tensor.contiguous +] +RESHAPE_FUNC_OP = [ + torch.flatten, + torch.reshape, + torch.transpose, + torch.split, + torch.permute, + operator.getitem, +] +RESHAPE_METHOD_OP = [ + torch.Tensor.view, + torch.Tensor.unsqueeze, + torch.Tensor.split, + torch.Tensor.permute, + torch.Tensor.transpose, +] +BCAST_FUNC_OP = [ + torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub, + operator.mul, operator.floordiv, operator.truediv, torch.matmul, operator.pow, torch.pow +] +CONV_MODULE_OP = [ + torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d +] +CONV_FUNC_OP = [ + torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d +] +EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding] +LINEAR_MODULE_OP = [torch.nn.Linear] +LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm] +BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm] +LAYERNORM_MODULE_OP = [torch.nn.LayerNorm] +POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d] +NON_PARAM_FUNC_OP = [ + torch.flatten, + torch.reshape, + torch.abs, + torch.cos, + torch.exp, + operator.neg, + torch.multiply, + torch.nn.functional.relu, + torch.nn.functional.dropout, + torch.flatten, + torch.where, + operator.pow, + torch.pow, + torch.tanh, + torch.add, + torch.sub, + torch.mul, + torch.div, + torch.floor_divide, + torch.true_divide, + operator.add, + operator.sub, + operator.mul, + operator.floordiv, + operator.truediv, + # softmax should not be here + torch.nn.functional.softmax +] + +INFINITY_COST = 1e13 diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py b/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a081ce69c10ff1876f8bbe193f94be14aa850d90 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py @@ -0,0 +1,6 @@ +from .options import SolverOptions +from .strategies_constructor import StrategiesConstructor +from .sharding_strategy import ShardingStrategy, StrategiesVector +from .cost_graph import CostGraph +from .solver import Solver +from .graph_analysis import GraphAnalyser \ No newline at end of file diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py b/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a72d97554cc608c0943f9f84f5913098c217dd9d --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py @@ -0,0 +1,141 @@ +import functools +import operator +import warnings +from functools import reduce +from typing import Dict, List, Optional, Union + +import torch +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec +from torch.fx.node import Node + +from .constants import INFINITY_COST + + +def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, + dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: + """ + Generate the sharding spec of the tensor based on the given dim_partition_dict. + + + Args: + input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node. + device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster. + dim_partition_dict (Dict[int, List[int]]): a dictionary to specify the sharding specs, the key is the tensor dimension and the value is the mesh dimension for sharding. + """ + + if isinstance(input_, Node): + assert hasattr(input_, '_meta_data'), f'The given node has no attribte _meta_data' + meta_tensor = input_._meta_data + assert meta_tensor is not None, "The given node's _meta_data attribute is None" + shape = meta_tensor.shape + elif isinstance(input_, torch.Tensor): + shape = input_.shape + else: + raise TypeError( + f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.' + ) + for dim_index, sharding_index_list in dim_partition_dict.items(): + sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list] + sharding_size = reduce(operator.mul, sharding_list, 1) + assert shape[ + dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.' + + sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict) + return sharding_spec + + +def generate_resharding_costs(nodes: List[Node], + sharding_specs: List[ShardingSpec], + count_backward: Optional[bool] = True, + dtype: Optional[torch.dtype] = None, + index=None): + ''' + Compute the resharding costs with this specific strategy. + + Argument: + nodes (List[Node]): a list of nodes + sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes. + count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference. + dtype (Optional[torch.dtype]): the data type for cost calculation, default is None. + ''' + # The resharding_cost of weight is counted due to sharing weight cases. + resharding_costs = {} + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + + # shape consistency manager is a singleton class + shape_consistency_manager = ShapeConsistencyManager() + + for input_node, input_spec in zip(nodes, sharding_specs): + resharding_costs[input_node] = [] + for strategy in input_node.strategies_vector: + input_sharding_spec = strategy.output_sharding_spec + if not isinstance(input_sharding_spec, ShardingSpec): + assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.' + input_sharding_spec = input_sharding_spec[index] + assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' + try: + # compute the resharding cost + _, _, total_resharding_cost = shape_consistency_manager.shape_consistency( + input_sharding_spec, input_spec) + + # we need multiply the size of elem dtype to get correct communication cost + resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes + except AssertionError as e: + warnings.warn(f'{e}') + resharding_cost = INFINITY_COST + resharding_costs[input_node].append(resharding_cost) + return resharding_costs + + +def ignore_sharding_exception(func): + """ + A function wrapper which executes the function with a specified seed. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + rst = func(*args, **kwargs) + return rst + except AssertionError as e: + warnings.warn(f'{e}') + + return wrapper + + +def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size): + dim_partition_list = [] + # enumerate all the 2D sharding cases + for i in range(dim_size): + for j in range(i + 1, dim_size): + dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]} + dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]} + dim_partition_list.append(dim_partition_dict_0) + dim_partition_list.append(dim_partition_dict_1) + for i in range(dim_size): + dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]} + dim_partition_list.append(dim_partition_dict_flatten) + + return dim_partition_list + + +def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size): + dim_partition_list = [] + # enumerate all the 1D sharding cases + for i in range(dim_size): + dim_partition_dict_0 = {i: [mesh_dim_0]} + dim_partition_list.append(dim_partition_dict_0) + + return dim_partition_list + + +def generate_sharding_size(dim_partition_dict, device_mesh): + total_sharding_size = 1 + for mesh_dim_list in dim_partition_dict.values(): + mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list] + sharding_size = reduce(operator.mul, mesh_dim_sharding_size) + total_sharding_size *= sharding_size + + return total_sharding_size diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/constants.py b/colossalai/auto_parallel/tensor_shard/deprecated/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..91c20d3434872a3cbbcc3f538563b8267ef9b9ed --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/constants.py @@ -0,0 +1,83 @@ +import torch +import operator + +__all__ = [ + 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', + 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP', + 'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST' +] + +ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] +ELEMENTWISE_FUNC_OP = [ + torch.abs, + torch.cos, + torch.exp, + operator.neg, + torch.multiply, + torch.nn.functional.relu, + torch.nn.functional.dropout, + # softmax should not be here + torch.nn.functional.softmax +] +ELEMENTWISE_METHOD_OP = [ + torch.Tensor.to, + torch.Tensor.type, + # TODO: contiguous maybe need some extra processes. + torch.Tensor.contiguous +] +RESHAPE_FUNC_OP = [torch.flatten, torch.reshape] +RESHAPE_METHOD_OP = [ + torch.Tensor.view, + torch.Tensor.unsqueeze, + torch.Tensor.split, + torch.Tensor.permute, + torch.Tensor.transpose, +] +BCAST_FUNC_OP = [ + torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub, + operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh +] +CONV_MODULE_OP = [ + torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d +] +CONV_FUNC_OP = [ + torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d +] +EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding] +LINEAR_MODULE_OP = [torch.nn.Linear] +LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm] +BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm] +LAYERNORM_MODULE_OP = [torch.nn.LayerNorm] +POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d] +NON_PARAM_FUNC_OP = [ + torch.flatten, + torch.reshape, + torch.abs, + torch.cos, + torch.exp, + operator.neg, + torch.multiply, + torch.nn.functional.relu, + torch.nn.functional.dropout, + torch.flatten, + torch.where, + operator.pow, + torch.pow, + torch.tanh, + torch.add, + torch.sub, + torch.mul, + torch.div, + torch.floor_divide, + torch.true_divide, + operator.add, + operator.sub, + operator.mul, + operator.floordiv, + operator.truediv, + # softmax should not be here + torch.nn.functional.softmax +] + +INFINITY_COST = 1e13 diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py b/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..239d02115d0e61a75584383505db5f9ecd2fd039 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py @@ -0,0 +1,172 @@ +from typing import List +import math +from torch.fx.node import Node +from .constants import INFINITY_COST + + +class CostGraph: + ''' + A graph data structure to simplify the edge cost graph. It has two main functions: + 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in + CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list. + 2. To reduce the searching space, we merge computationally-trivial operators, such as + element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will + be given by the StrategiesVector depending on the type of target node and following nodes. + + Argument: + leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph. + simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True) + ''' + + def __init__(self, leaf_strategies, simplify=True): + self.leaf_strategies = leaf_strategies + self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies] + # stores number of strategies in each node + self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies} + # extra_node_costs will store the extra costs introduced by merging nodes + self.extra_node_costs = {} + self.following_dict = {} + self.simplify = simplify + self._build_cost_graph() + + def _remove_invalid_node(self, node, attr_name): + remove_list = [] + target_node_list = getattr(node, attr_name, []) + for target_node in target_node_list: + if target_node not in self.nodes: + remove_list.append(target_node) + for element in remove_list: + target_node_list.remove(element) + + def _build_cost_graph(self): + ''' + This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be + set to node. + ''' + self.edge_costs = {} + if self.simplify: + self.merge_pair = [] + for strategies_vector in self.leaf_strategies: + # build edge_cost + dst_node = strategies_vector.node + for src_node in strategies_vector.predecessor_nodes: + if src_node not in self.nodes: + continue + node_pair = (src_node, dst_node) + # src_index = strategies_vector.predecessor_nodes.index(src_node) + edge_cost = {} + for i in range(len(strategies_vector)): + for j in range(len(src_node.strategies_vector)): + edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j] + self.edge_costs[node_pair] = edge_cost + # add parents and children attribute to node + setattr(dst_node, 'parents', strategies_vector.predecessor_nodes) + setattr(dst_node, 'children', strategies_vector.successor_nodes) + self._remove_invalid_node(dst_node, 'parents') + self._remove_invalid_node(dst_node, 'children') + + if self.simplify and strategies_vector.check_merge(): + for followed_node in strategies_vector.predecessor_nodes: + self.merge_pair.append((followed_node, dst_node)) + + def get_edge_cost(self, src_node, dst_node): + return self.edge_costs[(src_node, dst_node)] + + def merge_node(self, src_node, dst_node): + ''' + To merge dst_node into src_node, we need to do it in following steps: + + 1. For each strategy in dst_node, we need to pick an appropriate strategy + of src_node to merge, it is important because the logical resharding costs + between the parents node of src_node and merged node depend on the src_node + strategies dispatching. For example, for the graph 0->1->2, after merging node 1 + into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)] + x represents the picking strategy of node 1 merged into node 2 strategy 0. + + 2. We need to accumulate the extra costs introduced by merging nodes, the extra costs + contains two parts, one is resharding costs between src_node strategy and dst_node strategy, + another is the origin extra costs in src_node strategy. + + 3. Build connections between new node pairs, and remove the src_node after all consumer nodes + detached from it. + + Argument: + src_node(Node): The node will be merged into dst_node. + dst_node(Node): The node to integrate src_node. + ''' + src_node_index = dst_node.parents.index(src_node) + # build merge_map + merge_map = {} + for src_index, strategy in enumerate(src_node.strategies_vector): + min_cost = INFINITY_COST + lowest_cost_index = -1 + for dst_index, dst_strategy in enumerate(dst_node.strategies_vector): + resharding_cost = dst_strategy.resharding_costs[src_node][src_index] + if resharding_cost <= min_cost: + min_cost = resharding_cost + lowest_cost_index = dst_index + merge_map[src_index] = lowest_cost_index + + # extra_node_cost for src node + self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node] + for src_index, strategy in enumerate(src_node.strategies_vector): + target_strate_index = merge_map[src_index] + target_strategy = dst_node.strategies_vector[target_strate_index] + self.extra_node_costs[src_node][src_index] += target_strategy.resharding_costs[src_node][src_index] + if dst_node in self.extra_node_costs: + self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index] + + # add new node pair to cost graph + for child_node in dst_node.children: + new_node_pair = (src_node, child_node) + old_node_pair = (dst_node, child_node) + if new_node_pair in self.edge_costs: + continue + edge_cost = {} + for i in range(self.node_lens[src_node]): + for j in range(self.node_lens[child_node]): + dst_strate_index = merge_map[i] + # dst_strategy = dst_node.strategies_vector[dst_strate_index] + edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)] + if new_node_pair not in self.edge_costs: + self.edge_costs[new_node_pair] = edge_cost + else: + # we should accumulate the resharding costs if args of child node contain + # both src node and dst node. + for index_pair, resharding_cost in self.edge_costs[new_node_pair]: + self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair] + + # connect src node and children of dst node + dst_node.parents.remove(src_node) + src_node.children.remove(dst_node) + self.edge_costs.pop((src_node, dst_node)) + for child_node in dst_node.children: + if child_node not in src_node.children: + src_node.children.append(child_node) + if src_node not in child_node.parents: + child_node.parents.append(src_node) + # remove dst node from cost graph when dst node has no producer. + if len(dst_node.parents) == 0: + child_node.parents.remove(dst_node) + node_pair = (dst_node, child_node) + self.edge_costs.pop(node_pair) + if len(dst_node.parents) == 0: + self.following_dict[dst_node] = src_node + dst_node.children = [] + + def _reindexing_src(self, src): + if src not in self.following_dict: + return src + return self._reindexing_src(self.following_dict[src]) + + def simplify_graph(self): + if not self.simplify: + return + self.merge_pair.reverse() + for (src_node, dst_node) in self.merge_pair: + self.merge_node(src_node, dst_node) + self.merge_pair.reverse() + reindexing_following_dict = {} + for dst, src in self.following_dict.items(): + reindexing_following_dict[dst] = self._reindexing_src(src) + self.following_dict = reindexing_following_dict diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py b/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..831e7eadd179dec27a4ab7ebd28fc5bb53edc4dd --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py @@ -0,0 +1,163 @@ +from dataclasses import dataclass +from torch.fx.node import Node +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from collections import OrderedDict as ODict +from typing import List, OrderedDict, Union, Any +from colossalai.fx.passes.utils import get_node_module + +__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser'] + + +@dataclass +class LiveVariable: + """ + LiveVariable is a data structure to store the meta information of a variable for liveness analysis. + """ + name: str + node: Node + is_inplace: bool + + +class LiveVariableVector(list): + """ + LiveVariableVector is a data structure to store the list of LiveVariable objects. + """ + + def exists(self, name) -> bool: + """ + Check if a variable has already existed in the current list by name. + """ + for var in self: + if name == var.name: + return True + return False + + def get(self, name) -> LiveVariable: + for var in self: + if name == var.name: + return var + raise KeyError(f"Variable {name} is not found") + + def copy(self) -> "LiveVariableVector": + """ + Create a copy of this vector + """ + vector = LiveVariableVector() + for var in self: + vector.append(var) + return vector + + +@dataclass +class LiveStage: + """ + LiveStage is a data structure to record the living variables at this current node. + """ + name: str + node: Node + all_live_vars: LiveVariableVector + unique_live_vars: LiveVariableVector + + +class GraphAnalyser: + + def __init__(self, gm: GraphModule): + self._gm = gm + self._graph = gm.graph + + @property + def gm(self) -> GraphModule: + """ + Return the GraphModule object associated with this analyser. + """ + return self._gm + + @property + def graph(self) -> Graph: + """ + Return the Graph object associated with this analyser. + """ + return self._graph + + def liveness_analysis(self) -> List[LiveStage]: + """ + Analyse the graph to obtain the variable liveness information. This function returns + an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object. + """ + compute_nodes = self.graph.nodes + liveness_list = [] + + # checked: record all variables created since the first stage + # all: record the live variables only exist until the current stage. + # this can be different from the `checked list`` as some varialbes may be destroyed prior to this stage. + # unique: record the unique live variables only exist until the current stage. + # this is different from `all list` as some variables are duplicated. + checked_variables = LiveVariableVector() + all_live_variables = LiveVariableVector() + unique_live_vars = LiveVariableVector() + + for idx, node in enumerate(compute_nodes): + ############################# + # find new living variables # + ############################# + # detect whether the current op is an in-place op + # if it is an in-place op, we would deem it as a duplciate var + is_inplace = False + if node.op == 'call_function': + # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True) + if node.kwargs.get('inplace', False): + is_inplace = True + elif node.op == 'call_module': + # to check if this is an inplace op such as torch.nn.Relu(inplace=True) + module = get_node_module(node) + if getattr(module, 'inplace', False): + is_inplace = True + + # add the output var + meta = getattr(node, '_meta_data', None) + live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace) + if not is_inplace: + unique_live_vars.append(live_var) + checked_variables.append(live_var) + all_live_variables.append(live_var) + + # check if any input is not checked yet + for arg in node.args: + if not isinstance(arg, Node): + continue + arg_name = arg.name + if not checked_variables.exists(arg_name): + live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False) + all_live_variables.append(live_var_from_arg) + checked_variables.append(live_var_from_arg) + unique_live_vars.append(live_var_from_arg) + + # TODO: add the logic to remove live variables + # this should be completed if we are able to trace the backward compute graph + + # add this stage to liveness dict + stage = LiveStage(name=node.name, + node=node, + all_live_vars=all_live_variables.copy(), + unique_live_vars=unique_live_vars.copy()) + # if a LiveStage is covered by another LiveStage, we just keep the larger one. + replace = False + for index, prev_stage in enumerate(liveness_list): + all_covered = True + for ele in prev_stage.unique_live_vars: + if ele not in stage.unique_live_vars: + all_covered = False + break + if all_covered: + replace = True + break + if replace: + liveness_list[index] = stage + else: + liveness_list.append(stage) + + return liveness_list + + def get_alias_set(self): + pass diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..723e1bcf95ed827f54b428e06106481efa0c22d9 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py @@ -0,0 +1,15 @@ +from .batch_norm_handler import BatchNormHandler +from .bcast_op_handler import BcastOpHandler +from .conv_handler import ConvHandler +from .dot_handler import DotHandler +from .embedding_handler import EmbeddingHandler +from .layer_norm_handler import LayerNormHandler +from .operator_handler import OperatorHandler +from .reshape_handler import ReshapeHandler +from .unary_elementwise_handler import UnaryElementwiseHandler +from .where_handler import WhereHandler + +__all__ = [ + 'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler', + 'UnaryElementwiseHandler', 'EmbeddingHandler', 'WhereHandler', 'LayerNormHandler' +] diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..51943627082878cfb86c253891ee48347a4f9166 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py @@ -0,0 +1,492 @@ +import operator +from functools import reduce + +import torch +from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ + ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) + +from .operator_handler import OperatorHandler + +__all__ = ['BatchNormHandler'] + + +class BatchNormHandler(OperatorHandler): + """ + A OperatorHandler which deals with the sharding strategies of normalization. + + To keep the math consistency, there are two way to do BatchNorm if the input + shards on batch dimension: + 1. We gather the input partitions through batch dimension, then do the normal BatchNorm. + 2. We do the SyncBatchNorm on the each input partition seperately, the SyncBN op will help + us to keep the computing correctness. + In this handler, both methods will be considered. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_data = self.predecessor_node[0]._meta_data + self.weight = self.module_named_parameters['weight'] + self.bias = self.module_named_parameters['bias'] + self.output_data = self.node._meta_data + self._sanity_check() + + def _sanity_check(self): + ''' + In sanity check, we need make sure the input data having correct dimension size. + For BatchNorm1d, the dim of input data should be 3([N, C, L]). + For BatchNorm2d, the dim of input data should be 4([N, C, H, W]). + For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]). + ''' + assert self.input_data.dim() in (3, 4, + 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' + + def _generate_compute_cost(self, bs, channel_in): + ''' + Compute the computation cost per device with this specific strategy. + + Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + + Argument: + bs(int): Batch size of the input data. + channel_in(int): The channel dimension of input data. + + Return: + compute_cost(float): Computation cost per device with this specific strategy + ''' + # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + # TODO: a constant coefficient need to be added. + # 1D: (L) * N * Cin + # 2D: (H * W) * N * Cin + # 3D: (H * W * D) * N * Cin + + input_size = self.input_data.shape[2:] + input_size_product = reduce(operator.mul, input_size, 1) + forward_compute_cost = input_size_product * bs * channel_in + backward_activation_compute_cost = input_size_product * bs * channel_in + backward_weight_compute_cost = input_size_product * bs * channel_in + backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost + compute_cost = forward_compute_cost + backward_compute_cost + return compute_cost + + def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight): + ''' + Compute the memory cost per device with this specific strategy. + + Argument: + sharding_size_forward(int): The forward activation will be divided + into sharding_size_forward number partions. + sharding_size_backward_activation(int): The backward activation will + be divided into sharding_size_backward_activation number partions. + sharding_size_weight(int): The backward weight will be divided + into sharding_size_weight number partions. + + Return: + memory_cost(Tuple[float]): Memory cost per device with this + specific strategy, the first element of this tuple is forward + memory cost, and the second element of this tuple is backward + memory cost. + memory_cost_forward(float): Memory cost of forward activation per + device with this specific strategy. + memory_cost_backward_activation(float): Memory cost of backward activation + per device with this specific strategy. + ''' + # compute the memory cost of this strategy + dtype = self.input_data.dtype + numel_output = self.output_data.numel() + numel_input = numel_output + numel_weight = self.weight.numel() + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + + # forward memory_cost + memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward + memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight + memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight + + # backward memory_cost + memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation + memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight + memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight + + # memory_cost pair + memory_cost = (memory_cost_forward, memory_cost_backward) + + return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation + + @ignore_sharding_exception + def split_input_channel(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' + + dim_partition_dict_for_input = {1: [mesh_dim_0]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_0]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {1: [mesh_dim_0]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] + channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0] + compute_cost = self._generate_compute_cost(bs, channel_in) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_0] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] + sharding_size_weight = self.device_mesh.shape[mesh_dim_0] + memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, + sharding_size_weight) + + # This strategy do not need to do all_reduce operation + communication_cost = 0 + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + # shard the output batch dimension to get all possible sharding strategy from this basic strategy + new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' + + dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]} + new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + # the computation cost is all the same + new_compute_cost = compute_cost + + # the memory cost need to be recomputed + # compute the memroy cost of new strategy + new_sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] + sharding_size_weight = self.device_mesh.shape[mesh_dim_0] + new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost( + new_sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) + + # the communication cost need to count the sharding cost into this strategy + # compute the communication cost of new strategy + origin_communication_cost = communication_cost + tiny_shard_cost = 10 + new_forward_communication_cost = tiny_shard_cost + # we need to all gather the batch dimension for the basic strategy + new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation, mesh_dim_1) + new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost + + sharding_strategies = ShardingStrategy(new_name, + output_sharding_spec=new_sharding_spec_for_output, + compute_cost=new_compute_cost, + communication_cost=new_communication_cost, + memory_cost=new_memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}' + + dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] + channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] * + self.device_mesh.shape[mesh_dim_1]) + compute_cost = self._generate_compute_cost(bs, channel_in) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, + sharding_size_weight) + + # This strategy do not need to do all_reduce operation + communication_cost = 0 + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def non_split(self, mesh_dim_0, mesh_dim_1): + name = f'RR = RR x R' + + dim_partition_dict_for_input = {} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] + channel_in = self.input_data.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in) + + # compute the memory cost of this strategy + sharding_size_forward = 1 + sharding_size_backward_activation = 1 + sharding_size_weight = 1 + memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, + sharding_size_weight) + + # This strategy do not need to do all_reduce operation + communication_cost = 0 + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + def _construct_batch_sharding_strategies(mesh_dim_list, new_name): + dim_partition_dict_for_output = {0: mesh_dim_list} + new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # the computation cost is all the same + new_compute_cost = compute_cost + + # the memory cost need to be recomputed + new_sharding_size_input = 1 + for mesh_dim in mesh_dim_list: + new_sharding_size_input = new_sharding_size_input * self.device_mesh.shape[mesh_dim] + new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost( + new_sharding_size_input, sharding_size_backward_activation, sharding_size_weight) + + # the communication cost need to count the sharding cost into this strategy + origin_communication_cost = communication_cost + tiny_shard_cost = 10 + new_forward_communication_cost = tiny_shard_cost + if len(mesh_dim_list) == 1: + new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation, + mesh_dim_list[0]) + else: + new_backward_communication_cost = self.device_mesh.flatten_device_mesh.all_gather_cost( + memory_cost_backward_activation, 0) + new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost + + new_sharding_strategy = ShardingStrategy(new_name, + output_sharding_spec=new_sharding_spec_for_output, + compute_cost=new_compute_cost, + communication_cost=new_communication_cost, + memory_cost=new_memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, + sharding_spec_for_weight)) + + return new_sharding_strategy + + # shard the output batch dimension to get all possible sharding strategy from this basic strategy + # shard on mesh_dim_0 + new_name = f'S{mesh_dim_0}R = RR x R' + mesh_dim_list = [mesh_dim_0] + new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name) + self.strategies_vector.append(new_sharding_strategy) + + # shard on mesh_dim_1 + new_name = f'S{mesh_dim_1}R = RR x R' + mesh_dim_list = [mesh_dim_1] + new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name) + self.strategies_vector.append(new_sharding_strategy) + + # shard on mesh_dim_0, mesh_dim_1 + new_name = f'S{mesh_dim_0}{mesh_dim_1}R = RR x R' + mesh_dim_list = [mesh_dim_0, mesh_dim_1] + new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name) + self.strategies_vector.append(new_sharding_strategy) + + @ignore_sharding_exception + def split_input_batch(self, mesh_dim_0): + name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN' + + dim_partition_dict_for_input = {0: [mesh_dim_0]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] + channel_in = self.input_data.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_0] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] + sharding_size_weight = 1 + memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, + sharding_size_backward_activation, + sharding_size_weight) + + # the all reduce communication will happen during the sync bn computing. + communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0) + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN' + + dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]) + channel_in = self.input_data.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_weight = 1 + memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, + sharding_size_backward_activation, + sharding_size_weight) + + # the all reduce communication will happen during the sync bn computing. + communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward_activation, 0) + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN' + + dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] + channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1] + compute_cost = self._generate_compute_cost(bs, channel_in) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_weight = self.device_mesh.shape[mesh_dim_1] + memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, + sharding_size_backward_activation, + sharding_size_weight) + + # the all reduce communication will happen during the sync bn computing. + communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0) + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + def register_strategy(self) -> StrategiesVector: + ''' + Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector. + + Example: + norm_handler = BatchNormHandler(node, strategies_vector, + self.shape_consistency_manager) + norm_handler.register_strategy() + for strategy in norm_handler.strategies_vector: + print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}') + + Output: + RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0 + RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0 + RR = RR x R, computation_cost: 262144, memory_cost: 1048576 + RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0 + ''' + + # RS = RS x S and strategies based on it, such as + # SS = RS x S + self.split_input_channel(0, 1) + self.split_input_channel(1, 0) + + # RR = RR x R and strategies based on it, such as + # SR = SR x R + self.non_split(0, 1) + + # RS01 = RS01 x S01 + self.split_input_channel_1d(0, 1) + + # SR = SR x R WITH SYNC_BN + self.split_input_batch(0) + self.split_input_batch(1) + + # SS = SS x S WITH SYNC_BN + self.split_input_both_dim(0, 1) + self.split_input_both_dim(1, 0) + + # S01R = S01R x R WITH SYNC_BN + self.split_input_batch_1d(0, 1) + + return self.strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..6ac6dce7667504bb15bfc5efb1e68a50fe23bf0d --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py @@ -0,0 +1,552 @@ +import operator +import warnings +from copy import deepcopy +from functools import reduce +from typing import Dict, List + +import torch +from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception) +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +from .operator_handler import OperatorHandler + +__all__ = ['BcastOpHandler'] + + +class BcastOpHandler(OperatorHandler): + """ + An OperatorHandler which deals with the sharding strategies of broadcast operators(such as operator.add). + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert len(self.predecessor_node) == 2 + self.lhs_data = self.predecessor_node[0]._meta_data + self.rhs_data = self.predecessor_node[1]._meta_data + self.lhs = self.predecessor_node[0] + self.rhs = self.predecessor_node[1] + self.output_data = self.node._meta_data + + def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: + shape = list(input_.shape) + + # padding the shape to the same length as output_data + while len(shape) < self.output_data.dim(): + shape.insert(0, 1) + shape = torch.Size(shape) + + # if the sharding happens on a size one dimension, we should record it as R. + processed_dim_partition_dict = deepcopy(dim_partition_dict) + for dim_index, _ in dim_partition_dict.items(): + if shape[dim_index] == 1: + processed_dim_partition_dict.pop(dim_index) + for dim_index, sharding_index_list in processed_dim_partition_dict.items(): + sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list] + sharding_size = reduce(operator.mul, sharding_list, 1) + assert shape[ + dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.' + sharding_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=shape, + dim_partition_dict=processed_dim_partition_dict) + + return sharding_spec + + def _generate_compute_cost(self, total_sharding_size): + lhs_matrix_shape = self.lhs_data.shape[-2:] + rhs_matrix_shape = self.rhs_data.shape[-2:] + batch_dimensions_shape = self.output_data.shape[:-2] + batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1) + compute_cost = reduce( + operator.mul, lhs_matrix_shape) * rhs_matrix_shape[0] * batch_dimensions_product * 2 / total_sharding_size + return compute_cost + + def _generate_resharding_costs(self, sharding_specs): + # The resharding_cost of weight is counted due to sharing weight cases. + dtype = self.node._meta_data.dtype + nodes = self.predecessor_node + resharding_costs = {} + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + + # shape consistency manager is a singleton class + shape_consistency_manager = ShapeConsistencyManager() + + for input_node, input_spec in zip(nodes, sharding_specs): + resharding_costs[input_node] = [] + for strategy in input_node.strategies_vector: + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' + # if the input shape is smaller than the target input, we will fill the input to the same length as target. + # Then, use the padded input sharding spec to compute the resharding cost. + if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape): + new_entire_shape = list(input_sharding_spec.entire_shape) + while len(new_entire_shape) < len(input_spec.entire_shape): + new_entire_shape.insert(0, 1) + new_entire_shape = torch.Size(new_entire_shape) + new_device_mesh = input_sharding_spec.device_mesh + new_dim_partition_dict = input_sharding_spec.dim_partition_dict + input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh, + entire_shape=new_entire_shape, + dim_partition_dict=new_dim_partition_dict) + + # compute the resharding cost + _, _, total_resharding_cost = shape_consistency_manager.shape_consistency( + input_sharding_spec, input_spec) + + # we need multiply the size of elem dtype to get correct communication cost + resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes + resharding_costs[input_node].append(resharding_cost) + + return resharding_costs + + def _convert_partition_dict_to_sharding_spec(self, dim_partition_list): + + sharding_spec_list = [] + check_duplicated_list = [] + for output_dim_partition_dict in dim_partition_list: + try: + output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict) + except AssertionError as e: + warnings.warn(f'{e}') + break + sharding_seq = output_sharding_spec.sharding_sequence + if sharding_seq not in check_duplicated_list: + check_duplicated_list.append(sharding_seq) + sharding_spec_list.append(output_sharding_spec) + + return sharding_spec_list + + def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): + # use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 in here for N-D device mesh scaliablity. + + output_dim_partition_list = [] + dim_size = self.output_data.dim() + # enumerate all the 2D sharding cases + sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size) + output_dim_partition_list.extend(sharding_list_2d) + + # enumerate all the 1D sharding cases + sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size) + output_dim_partition_list.extend(sharding_list_1d_on_dim_0) + sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size) + output_dim_partition_list.extend(sharding_list_1d_on_dim_1) + + # add empty dict for fully replicated case + output_dim_partition_list.append({}) + output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list) + + return output_sharding_spec_list + + @ignore_sharding_exception + def _register_strategy(self, output_sharding_spec): + dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict + sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input) + sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_input) + + name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' + dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) + + # compute the computation cost of this strategy + sharding_dims = [] + for mesh_dims in dim_partition_dict_for_output.values(): + for mesh_dim in mesh_dims: + sharding_dims.append(self.device_mesh.shape[mesh_dim]) + sharding_size = reduce(operator.mul, sharding_dims, 1) + memory_cost = self.output_data.numel() / sharding_size + compute_cost = memory_cost + communication_cost = 0 + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=output_sharding_spec, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) + + self.strategies_vector.append(sharding_strategies) + + ############################################## + #used to generate strategies for torch.matmul# + ############################################## + @ignore_sharding_exception + def _registry_no_split_strategies_for_matmul(self, dim_partition_dict_for_batch_dim): + # this dim partition dict only describes the batch dimensions, but in this scenario, + # matrix dimensions are fully replicated, so it do not need extra process. + sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_batch_dim) + sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_batch_dim) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_batch_dim) + + name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) + + # compute the memory cost of this strategy + batch_sharding_dims = [] + for mesh_dims in dim_partition_dict_for_batch_dim.values(): + for mesh_dim in mesh_dims: + batch_sharding_dims.append(self.device_mesh.shape[mesh_dim]) + batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1) + # in this case, total_sharding_size is equal to the batch sharding size + memory_cost = self.output_data.numel() / batch_sharding_size + + # compute the computation cost of this strategy + compute_cost = self._generate_compute_cost(batch_sharding_size) + + # in this case, no communication takes place. + # TODO: add all-reduce cost if lhs or rhs is type of Parameters. + communication_cost = 0 + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) + + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def _split_dim_i(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix): + # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j] + # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it. + # In this scenario, matrix dimensions will be sharded on 'i' dimension. + + # in this case, the matrix dimensions of lhs is sharded on 'i' dimension. + dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim) + dim_partition_dict_for_lhs.update({-2: mesh_dim_on_matrix}) + + # in this case, the matrix dimensions of rhs is fully replicated. + dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim) + + # in this case, the matrix dimensions of output is sharded on 'i' dimension. + + dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim) + dim_partition_dict_for_output.update({-2: mesh_dim_on_matrix}) + + # generate sharding specs + sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) + sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) + + # compute the memory cost of this strategy + total_sharding_dims = [] + + # append batch sharding dims + for mesh_dims in dim_partition_dict_for_batch_dim.values(): + for mesh_dim in mesh_dims: + total_sharding_dims.append(self.device_mesh.shape[mesh_dim]) + + # append the sharding dims on matrix dimension + for mesh_dim in mesh_dim_on_matrix: + total_sharding_dims.append(self.device_mesh.shape[mesh_dim]) + total_sharding_size = reduce(operator.mul, total_sharding_dims, 1) + + # in this case, output_data uses all the sharding dims. + memory_cost = self.output_data.numel() / total_sharding_size + compute_cost = self._generate_compute_cost(total_sharding_size) + + # TODO: add all-reduce cost if lhs or rhs is type of Parameters. + communication_cost = 0 + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) + + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def _split_dim_k(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix): + # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j] + # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it. + # In this scenario, matrix dimensions will be sharded on 'k' dimension. + + # in this case, the matrix dimensions of lhs is sharded on 'k' dimension. + dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim) + dim_partition_dict_for_lhs.update({-1: mesh_dim_on_matrix}) + + # in this case, the matrix dimensions of rhs is sharded on 'k' dimension. + dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim) + dim_partition_dict_for_rhs.update({-2: mesh_dim_on_matrix}) + + # in this case, the matrix dimensions of output is fully replicated. + dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim) + + # generate sharding specs + sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) + sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) + + # compute the memory cost of this strategy + total_sharding_dims = [] + batch_sharding_dims = [] + # append batch sharding dims + for mesh_dims in dim_partition_dict_for_batch_dim.values(): + for mesh_dim in mesh_dims: + total_sharding_dims.append(self.device_mesh.shape[mesh_dim]) + batch_sharding_dims.append(self.device_mesh.shape[mesh_dim]) + + # append the sharding dims on matrix dimension + for mesh_dim in mesh_dim_on_matrix: + total_sharding_dims.append(self.device_mesh.shape[mesh_dim]) + batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1) + total_sharding_size = reduce(operator.mul, total_sharding_dims, 1) + + # in this case, output_data is fully replicated on matrix dimensions. + memory_cost = self.output_data.numel() / batch_sharding_size + + compute_cost = self._generate_compute_cost(total_sharding_size) + + # TODO: add all-reduce cost if lhs or rhs is type of Parameters. + # The communication takes place during forward activation computation. + if len(mesh_dim_on_matrix) == 1: + communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0]) + else: + communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0) + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) + + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def _split_dim_j(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix): + # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j] + # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it. + # In this scenario, matrix dimensions will be is sharded on 'j' dimension. + + # in this case, the matrix dimensions of lhs is fully replicated. + dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim) + + # in this case, the matrix dimensions of rhs is sharded on 'j' dimension. + dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim) + dim_partition_dict_for_rhs.update({-1: mesh_dim_on_matrix}) + + # in this case, the matrix dimensions of output is sharded on 'j' dimension. + dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim) + dim_partition_dict_for_output.update({-1: mesh_dim_on_matrix}) + + # generate sharding specs + sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) + sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) + + # compute the memory cost of this strategy + total_sharding_dims = [] + + # append batch sharding dims + for mesh_dims in dim_partition_dict_for_batch_dim.values(): + for mesh_dim in mesh_dims: + total_sharding_dims.append(self.device_mesh.shape[mesh_dim]) + + # append the sharding dims on matrix dimension + for mesh_dim in mesh_dim_on_matrix: + total_sharding_dims.append(self.device_mesh.shape[mesh_dim]) + total_sharding_size = reduce(operator.mul, total_sharding_dims, 1) + + # in this case, output_data uses all the sharding dims. + memory_cost = self.output_data.numel() / total_sharding_size + compute_cost = self._generate_compute_cost(total_sharding_size) + + # TODO: add all-reduce cost if lhs or rhs is type of Parameters. + # The communication takes place during backward activation computation. + if len(mesh_dim_on_matrix) == 1: + communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0]) + else: + communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0) + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) + + self.strategies_vector.append(sharding_strategies) + + def _registry_1d_strategies_for_matmul(self, dim_partition_dict, mesh_dim_list): + self._split_dim_i(dim_partition_dict, mesh_dim_list) + self._split_dim_k(dim_partition_dict, mesh_dim_list) + self._split_dim_j(dim_partition_dict, mesh_dim_list) + + @ignore_sharding_exception + def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): + dim_partition_dict_for_lhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]} + sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) + + dim_partition_dict_for_rhs = {-2: [mesh_dim_1]} + sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs) + + dim_partition_dict_for_output = {-2: [mesh_dim_0]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) + + # compute the memory cost of this strategy + total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1) + output_sharding_size = reduce(operator.mul, self.output_data.shape, 1) + # in this case, output_data uses all the sharding dims. + memory_cost = self.output_data.numel() / output_sharding_size + compute_cost = self._generate_compute_cost(total_sharding_size) + + # TODO: add all-reduce cost if lhs or rhs is type of Parameters. + # The communication takes place during forward activation computation. + communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1) + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) + + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): + dim_partition_dict_for_lhs = {-1: [mesh_dim_0]} + sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) + + dim_partition_dict_for_rhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]} + sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs) + + dim_partition_dict_for_output = {-1: [mesh_dim_1]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) + + # compute the memory cost of this strategy + total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1) + output_sharding_size = reduce(operator.mul, self.output_data.shape, 1) + # in this case, output_data uses all the sharding dims. + memory_cost = self.output_data.numel() / output_sharding_size + compute_cost = self._generate_compute_cost(total_sharding_size) + + # TODO: add all-reduce cost if lhs or rhs is type of Parameters. + # The communication takes place during forward and backward activation computation. + communication_cost_forward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0) + communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1) + communication_cost = communication_cost_backward_activation + communication_cost_forward_activation + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) + + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): + dim_partition_dict_for_lhs = {-2: [mesh_dim_0]} + sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) + + dim_partition_dict_for_rhs = {-1: [mesh_dim_1]} + sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs) + + dim_partition_dict_for_output = {-2: [mesh_dim_0], -1: [mesh_dim_1]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) + + # compute the memory cost of this strategy + total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1) + output_sharding_size = reduce(operator.mul, self.output_data.shape, 1) + # in this case, output_data uses all the sharding dims. + memory_cost = self.output_data.numel() / output_sharding_size + compute_cost = self._generate_compute_cost(total_sharding_size) + + # TODO: add all-reduce cost if lhs or rhs is type of Parameters. + # The communication takes place during backward activation computation. + communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1) + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) + + self.strategies_vector.append(sharding_strategies) + + def _registry_2d_strategies_for_matmul(self): + self._split_lhs_space_both_contract(0, 1) + self._split_lhs_space_both_contract(1, 0) + self._split_rhs_space_both_contract(0, 1) + self._split_rhs_space_both_contract(1, 0) + self._split_lhs_space_rhs_space(0, 1) + self._split_lhs_space_rhs_space(1, 0) + + def register_strategy(self) -> StrategiesVector: + MESH_DIM_LIST = [0, 1] + if self.node.target != torch.matmul: + output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1]) + for output_sharding_spec in output_sharding_specs: + self._register_strategy(output_sharding_spec) + else: + # we only care about the non-computing dimensions, + # therefore, we omit the last two dimensions. + dim_size = self.output_data.dim() - 2 + + # Both device mesh axises are uesd on batch dimensions + dim_partition_dicts_2d = enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1], dim_size) + for dim_partition_dict in dim_partition_dicts_2d: + self._registry_no_split_strategies_for_matmul(dim_partition_dict) + + # Only one device mesh axis is uesd on batch dimensions + for mesh_dim_index in [0, 1]: + dim_partition_dicts_1d = enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index], dim_size) + for dim_partition_dict in dim_partition_dicts_1d: + self._registry_no_split_strategies_for_matmul(dim_partition_dict) + self._registry_1d_strategies_for_matmul(dim_partition_dict, [MESH_DIM_LIST[mesh_dim_index - 1]]) + + # No device mesh axis is uesd on batch dimensions + dim_partition_dict_on_batch_dim = {} + self._registry_no_split_strategies_for_matmul(dim_partition_dict_on_batch_dim) + self._registry_1d_strategies_for_matmul(dim_partition_dict_on_batch_dim, MESH_DIM_LIST) + self._registry_2d_strategies_for_matmul() diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..d8952040dffe4c8251bbe55bd39998a07ffd278d --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py @@ -0,0 +1,609 @@ +import operator +import warnings +from functools import reduce + +import torch + +from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector + +from .operator_handler import OperatorHandler + +__all__ = ['ConvHandler'] + + +class ConvHandler(OperatorHandler): + """ + An OperatorHandler which deals with the sharding strategies of Convolution. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_data = self.predecessor_node[0]._meta_data + self.weight = self.module_named_parameters['weight'] + self.output_data = self.node._meta_data + self._sanity_check() + + def _sanity_check(self): + ''' + In sanity check, we need make sure the input data having correct dimension size. + For Conv1d, the dim of input data should be 3([N, C, L]). + For Conv2d, the dim of input data should be 4([N, C, H, W]). + For Conv3d, the dim of input data should be 5([N, C, H, W, D]). + ''' + assert self.input_data.dim() in (3, 4, + 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' + + def _generate_compute_cost(self, bs, channel_in, channel_out): + ''' + Compute the computation cost per device with this specific strategy. + + Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + + Argument: + bs(int): Batch size of the input data. + channel_in(int): The channel dimension of input data. + channel_out(int): The out channel of the conv weight. + + Return: + compute_cost(float): Computation cost per device with this specific strategy + ''' + # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + # 1D: (L) * N * Cout * Cin * kernel + # 2D: (H * W) * N * Cout * Cin * kernel + # 3D: (H * W * D) * N * Cout * Cin * kernel + output_size = self.output_data.shape[2:] + output_size_product = reduce(operator.mul, output_size, 1) + input_size = self.input_data.shape[2:] + input_size_product = reduce(operator.mul, input_size, 1) + kernel_size = self.weight.shape[2:] + kernel_size_product = reduce(operator.mul, kernel_size, 1) + forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product + backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product + backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product + compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost + return compute_cost + + def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight): + ''' + Compute the memory cost per device with this specific strategy. + + Argument: + sharding_size_forward(int): The forward activation will be divided + into sharding_size_forward number partions. + sharding_size_backward_activation(int): The backward activation will + be divided into sharding_size_backward_activation number partions. + sharding_size_weight(int): The backward weight will be divided + into sharding_size_weight number partions. + + Return: + memory_cost(Tuple[float]): Memory cost per device with this + specific strategy, the first element of this tuple is forward + memory cost, and the second element of this tuple is backward + memory cost. + memory_cost_forward(float): Memory cost of forward activation per + device with this specific strategy. + memory_cost_backward_activation(float): Memory cost of backward activation + per device with this specific strategy. + ''' + # compute the memory cost of this strategy + dtype = self.input_data.dtype + numel_output = self.output_data.numel() + numel_input = self.input_data.numel() + numel_weight = self.weight.numel() + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + + # forward memory_cost + memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward + memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight + memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight + + # backward memory_cost + memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation + memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight + memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight + + # memory_cost pair + memory_cost = (memory_cost_forward, memory_cost_backward) + + return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight + + @ignore_sharding_exception + def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' + + dim_partition_dict_for_input = {0: [mesh_dim_0]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {1: [mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] + channel_in = self.input_data.shape[1] + channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1] + compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] + sharding_size_weight = self.device_mesh.shape[mesh_dim_1] + memory_cost, _, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost( + sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) + + # This strategy do not need to do all_reduce operation during forward + communication_cost_forward = 0 + # compute the backward communication cost to all reduce the input activation grad + communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, + mesh_dim_1) + # compute the backward communication cost to all reduce the weight due to data parallel + communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0) + # total communication cost + communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_input_batch(self, mesh_dim_0): + name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' + + dim_partition_dict_for_input = {0: [mesh_dim_0]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] + channel_in = self.input_data.shape[1] + channel_out = self.weight.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_0] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] + sharding_size_weight = 1 + memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward, + sharding_size_backward_activation, + sharding_size_weight) + + # This strategy do not need to do all_reduce operation in forward phase. + communication_cost_forward = 0 + # compute the backward communication cost to all reduce the weight due to data parallel + communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0) + # compute the total cost + communication_cost = communication_cost_forward + communication_cost_backward_weight + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' + + dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_0]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] + channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1] + channel_out = self.weight.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_0] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_weight = self.device_mesh.shape[mesh_dim_1] + memory_cost, memory_cost_forward_activation, _, memory_cost_backward_weight = self._generate_memory_cost( + sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) + + # compute the communication cost of this strategy during forward phase + communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1) + # This strategy do not need to do all_reduce operation to compute the input activation grad + communication_cost_backward_activation = 0 + # compute the backward communication cost to all reduce the weight due to data parallel + communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0) + # compute total cost + communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' + + dim_partition_dict_for_input = {1: [mesh_dim_0]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {1: [mesh_dim_1]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] + channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0] + channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1] + compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_1] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] + sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost( + sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) + + # compute the communication cost of this strategy during forward phase + communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0) + # compute the communication cost of this strategy during backward phase + communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1) + communication_cost = communication_cost_forward + communication_cost_backward + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_input_in_channel_weight_in_channel(self, mesh_dim_0): + name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' + + dim_partition_dict_for_input = {1: [mesh_dim_0]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_0]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] + channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0] + channel_out = self.weight.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) + + # compute the memory cost of this strategy + sharding_size_forward = 1 + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] + sharding_size_weight = self.device_mesh.shape[mesh_dim_0] + memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost( + sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) + + # compute the communication cost of this strategy during forward phase + communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0) + # This strategy do NOT need all_reduce during forward phase + communication_cost_backward = 0 + communication_cost = communication_cost_forward + communication_cost_backward + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_weight_out_channel(self, mesh_dim_0): + name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' + + dim_partition_dict_for_input = {} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {1: [mesh_dim_0]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {1: [mesh_dim_0]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] + channel_in = self.input_data.shape[1] + channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_0] + compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_0] + sharding_size_backward_activation = 1 + sharding_size_weight = self.device_mesh.shape[mesh_dim_0] + memory_cost, _, memory_cost_backward_activation, _ = self._generate_memory_cost( + sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) + + # This strategy do not need to do all_reduce during forward phase + communication_cost_forward = 0 + # compute the communication cost of this strategy during backward phase + communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0) + communication_cost = communication_cost_forward + communication_cost_backward + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def non_split(self): + name = f'RR = RR x RR' + + dim_partition_dict_for_input = {} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] + channel_in = self.input_data.shape[1] + channel_out = self.weight.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) + + # compute the memory cost of this strategy + sharding_size_forward = 1 + sharding_size_backward_activation = 1 + sharding_size_weight = 1 + memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, + sharding_size_weight) + + # This strategy do not need to do all_reduce in both forward and backward phase + communication_cost = 0 + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + + dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]) + channel_in = self.input_data.shape[1] + channel_out = self.weight.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1] + sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[ + mesh_dim_1] + sharding_size_weight = 1 + memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward, + sharding_size_backward_activation, + sharding_size_weight) + + # This strategy do not need to do all_reduce in forward phase + communication_cost_forward = 0 + # compute the backward communication cost to all reduce the weight due to data parallel + communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost( + memory_cost_backward_weight, 0) + # compute the total communication cost + communication_cost = communication_cost_backward_weight + communication_cost_forward + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): + name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + + dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] + channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] * + self.device_mesh.shape[mesh_dim_1]) + channel_out = self.weight.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) + + # compute the memory cost of this strategy + sharding_size_forward = 1 + sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[ + mesh_dim_1] + sharding_size_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1] + memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost( + sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) + + # compute communication cost during forward phase + communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost( + memory_cost_forward_activation, 0) + # This strategy do NOT need do all_reduce during backward phase + communication_cost_backward = 0 + communication_cost = communication_cost_forward + communication_cost_backward + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + def register_strategy(self) -> StrategiesVector: + ''' + Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector. + + Example: + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + shape_consistency_manager = ShapeConsistencyManager() + + model = ConvModel(16, 32) + input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) + # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {}) + # return conv + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + # [x, mul, conv, output] + nodes = [node for node in gm.graph.nodes] + + # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]] + strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input) + setattr(nodes[1], 'strategies_vector', strategies_vector_for_input) + + strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ]) + conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2], + device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager) + conv_handler.register_strategy_into_strategies_vector() + for strategy in conv_handler.strategies_vector: + print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}') + + Output: + S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]} + S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]} + S0R = S0R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]} + S1R = S1R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]} + S0R = S0S1 x S1R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 65538.002, 0, 0, 0, 65538.002, 196614.402]} + S1R = S1S0 x S0R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 65538.002, 65538.002, 196614.402, 0, 0]} + RS1 = RS0 x S0S1: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]} + RS0 = RS1 x S1S0: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]} + RR = RS0 x S0R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]} + RR = RS1 x S1R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]} + RS0 = RR x RS0: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]} + RS1 = RR x RS1: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]} + RR = RR x RR: compute_cost is 35426304, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]} + S01R = S01R x RR: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 65538.002, 262148.4, 0, 16385.001, 262148.4, 196614.402]} + RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]} + ''' + # SS = SR x RS + self.split_input_batch_weight_out_channel(0, 1) + self.split_input_batch_weight_out_channel(1, 0) + + # SR = SR x RR + self.split_input_batch(0) + self.split_input_batch(1) + + # SR = SS x SR + self.split_input_both_dim_weight_in_channel(0, 1) + self.split_input_both_dim_weight_in_channel(1, 0) + + # RS = RS x SS + self.split_input_in_channel_weight_both_channel(0, 1) + self.split_input_in_channel_weight_both_channel(1, 0) + + # RR = RS x SR + self.split_input_in_channel_weight_in_channel(0) + self.split_input_in_channel_weight_in_channel(1) + + # RS = RR x RS + self.split_weight_out_channel(0) + self.split_weight_out_channel(1) + + # RR= RR x RR + self.non_split() + + # S01R = S01R x RR + self.split_1d_parallel_on_input_batch(0, 1) + + # RR = RS01 x S01R + self.split_1d_parallel_on_in_channel(0, 1) + + return self.strategies_vector + + +CONV_STRATEGIES_LIST = [ + 'S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', + 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', + 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R' +] diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..4feeacd983df7698580e37221f80f0855f8f4418 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py @@ -0,0 +1,756 @@ +import operator +from enum import Enum +from functools import reduce +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ + ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) + +from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP +from .operator_handler import OperatorHandler +from .strategy_generator import IntermediateStrategy, StrategyGenerator + +__all__ = ['DotHandler'] + + +class DotProductStrategyGenerator(StrategyGenerator): + """ + DotProductStrategyGenerator is used to generate the sharding strategies for two 1D tensors in dot product computation. + This is created for torch.matmul where two tensors are 1D tensors. As torch.matmul does not include a bias argument, so we + do not consider bias here. + """ + + def validate(self, input, other): + assert input.dim() == 1 and other.dim() == 1 + + def no_split(self): + name = f'R = R dot R' + dim_partition_dict = {"input": {}, "other": {}, "output": {}} + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) + + def split_one_dim(self, mesh_dim): + name = f'S{mesh_dim} = S{mesh_dim} dot S{mesh_dim}' + dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}} + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim]) + + def generate(self) -> List[IntermediateStrategy]: + strategy_list = [] + + # do not split dimensions for dot product + # R = R dot R + strategy_list.append(self.no_split()) + + # split two tensors in the same dimensions + # S = S dot S + strategy_list.append(self.split_one_dim(0)) + strategy_list.append(self.split_one_dim(1)) + + return strategy_list + + +class MatVecStrategyGenerator(StrategyGenerator): + + def validate(self, input, other) -> bool: + assert input.dim() > 1 and other.dim() == 1 + + def no_split(self): + name = "R = R x R" + dim_partition_dict = {"input": {}, "other": {}, "output": {}} + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) + + def split_input_batch(self, mesh_dim): + name = f'S{mesh_dim}R = S{mesh_dim}R x R' + dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}} + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) + + def generate(self) -> List[IntermediateStrategy]: + strategy_list = [] + + # no split + strategy_list.append(self.no_split()) + + # split the batch dim for the first tensor only + strategy_list.append(self.split_input_batch(0)) + strategy_list.append(self.split_input_batch(1)) + + return strategy_list + + +class MatMulStrategyGenerator(StrategyGenerator): + """ + MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is + a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm. + + A matmul can be formulated as [n, p] x [p, q] = [n, q] + + Args: + is_linear (bool): whether this generator is used for nn.Linear and F.linear. + This will incur extra transformation of the dim partitioning as the weight is transposed. + """ + + def __init__(self, is_linear: bool, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_linear = is_linear + + # as the weight for the linear module is transposed, we can compute + # the correponding dimension indexfor convenience + if is_linear: + self.dim_q = 0 + self.dim_p = 1 + else: + self.dim_q = 1 + self.dim_p = 0 + + def validate(self, input, other, bias) -> bool: + # make sure the second tensor is a 2D tensor + assert input.dim() > 0 and other.dim() == 2 + + # make sure bias is of the same dimension + if self.is_linear: + assert bias is None or bias.shape[-1] == other.shape[0] + else: + assert bias is None or bias.shape[-1] == other.shape[1] + + def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): + # handle case SS = SR x RS + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' + + dim_partition_dict = { + "input": { + 0: [mesh_dim_0] + }, + "other": { + self.dim_q: [mesh_dim_1] + }, + "bias": { + -1: [mesh_dim_1] + }, + "output": { + 0: [mesh_dim_0], + -1: [mesh_dim_1] + }, + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) + + def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): + # handle the case SR = SS x SR + name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0], + -1: [mesh_dim_1] + }, + "other": { + self.dim_p: [mesh_dim_1] + }, + "bias": {}, + "output": { + 0: [mesh_dim_0] + }, + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1]) + + def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' + dim_partition_dict = { + "input": { + -1: [mesh_dim_0] + }, + "other": { + self.dim_p: [mesh_dim_0], + self.dim_q: [mesh_dim_1] + }, + "bias": { + -1: [mesh_dim_1] + }, + "output": { + -1: [mesh_dim_1] + }, + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) + + def recompute_split_both_contract(self, mesh_dim): + name = f'RR = RS{mesh_dim} x S{mesh_dim}R' + dim_partition_dict = { + "input": { + -1: [mesh_dim] + }, + "other": { + self.dim_p: [mesh_dim] + }, + "bias": {}, + "output": {}, + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim]) + + def split_rhs_space_only(self, mesh_dim): + name = f'RS{mesh_dim} = RR x RS{mesh_dim}' + dim_partition_dict = { + "input": {}, + "other": { + self.dim_q: [mesh_dim] + }, + "bias": { + -1: [mesh_dim] + }, + "output": { + -1: [mesh_dim] + }, + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim]) + + def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "other": {}, + "bias": {}, + "output": { + 0: [mesh_dim_0, mesh_dim_1] + }, + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) + + def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + dim_partition_dict = { + "input": { + -1: [mesh_dim_0, mesh_dim_1] + }, + "other": { + self.dim_p: [mesh_dim_0, mesh_dim_1] + }, + "bias": {}, + "output": {}, + } + return IntermediateStrategy(name=name, + dim_partition_dict=dim_partition_dict, + all_reduce_axis=[mesh_dim_0, mesh_dim_1]) + + def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' + + dim_partition_dict = { + "input": {}, + "other": { + self.dim_q: [mesh_dim_0, mesh_dim_1] + }, + "bias": { + -1: [mesh_dim_0, mesh_dim_1] + }, + "output": { + -1: [mesh_dim_0, mesh_dim_1] + }, + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) + + +class BatchedMatMulStrategyGenerator(StrategyGenerator): + """ + Generate sharding strategies for the batched matrix multiplication. + + A batched matrix multiplication can be viewed as + [b, i, k] x [b, k, j] -> [b, i, j] + """ + + def __init__(self, is_torch_bmm: bool, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_torch_bmm = is_torch_bmm + + def validate(self, input, other, bias) -> bool: + if self.is_torch_bmm: + assert input.shape == other.shape + assert input.dim() > 2 + assert other.shape[-1] == bias.shape[0] + else: + # TODO: validate these inputs are broadcastable + pass + + def split_one_batch_dim(self): + if 1 in self.device_mesh.mesh_shape: + mesh_dim = self.device_mesh.mesh_shape.index(1) + name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}' + dim_partition_dict = { + "input": { + 0: [mesh_dim] + }, + "other": { + 0: [mesh_dim] + }, + "bias": {}, + "output": { + 0: [mesh_dim] + } + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) + else: + return None + + def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1): + name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "other": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "bias": {}, + "output": { + 0: [mesh_dim_0, mesh_dim_1] + } + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) + + def split_one_batch_dim(self, mesh_dim): + name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}' + dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}} + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) + + def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1): + name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0], + -2: [mesh_dim_1] + }, + "other": { + 0: [mesh_dim_0] + }, + "bias": {}, + "output": { + 0: mesh_dim_0, + -2: [mesh_dim_1] + } + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) + + def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1): + name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0] + }, + "other": { + 0: [mesh_dim_0], + -1: [mesh_dim_1] + }, + "bias": { + -1: [mesh_dim_1] + }, + "output": { + 0: [mesh_dim_0], + -1: [mesh_dim_1] + } + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) + + def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1): + name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0], + -1: [mesh_dim_1] + }, + "other": { + 0: [mesh_dim_0], + -2: [mesh_dim_1] + }, + "bias": {}, + "output": { + 0: [mesh_dim_0], + -2: [mesh_dim_1] + } + } + return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1]) + + def generate(self) -> List[IntermediateStrategy]: + strategy_list = [] + + # split only the batch dimension + # Sb = Sb x Sb + # can be None as it is only for 1D device mesh + strategy = self.split_one_batch_dim() + if strategy: + strategy_list.append(strategy) + + # split batch dim of two inputs and the i dim of the first tensor + # SbSi = SbSi x Sb + strategy_list.append(self.split_batch_dim_lhs_space(0, 1)) + strategy_list.append(self.split_batch_dim_lhs_space(1, 0)) + + # split batch dim of two inputs and the j of the second tensor + # SbSj = Sb x SbSj + strategy_list.append(self.split_batch_dim_rhs_space(0, 1)) + strategy_list.append(self.split_batch_dim_rhs_space(1, 0)) + + # split batch dim of two inputs and the k dim of two inputs + # Sb = SbSk x SbSk, need to all-reduce by k dim + strategy_list.append(self.split_batch_dim_both_contract(0, 1)) + strategy_list.append(self.split_batch_dim_both_contract(1, 0)) + + # split two batch dim + strategy_list.append(self.split_two_batch_dim(0, 1)) + strategy_list.append(self.split_two_batch_dim(1, 0)) + + return strategy_list + + +class DotHandler(OperatorHandler): + """ + A OperatorHandler which deals with the sharding strategies for nn.Linear and F.linear. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_data = self.predecessor_node[0]._meta_data + self.weight = self.module_named_parameters['weight'] + self.output_data = self.node._meta_data + + def _generate_compute_cost(self, input_shape, weight_shape, total_sharding_size): + # TODO: consider bias addition + compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size + return compute_cost + + @ignore_sharding_exception + def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): + # handle case SS = SR x RS + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' + + dim_partition_dict_for_input = {0: [mesh_dim_0]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + # linear layer weight is transposed during init + dim_partition_dict_for_weight = {0: [mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute computation cost + total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) + + # compute the memory cost of this strategy + toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) + + # compute the communication cost + communication_cost_activation_backward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1) + communication_cost_weight_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0) + communication_cost = communication_cost_activation_backward + communication_cost_weight_backward + + # create and register strategy + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_ouput, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=toatl_memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): + # handle the case SR = SS x SR + name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' + + dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + # since weight of the linear layer is transposed + # the actual dim to be sharded is 1 + dim_partition_dict_for_weight = {1: [mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0]} + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) + + # compute the memory cost of this strategy + toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) + + # compute the communication cost of this strategy + communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1) + communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0) + communication_cost = communication_cost_activation_forward + communication_cost_grad_backward + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_ouput, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=toatl_memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' + + dim_partition_dict_for_input = {1: [mesh_dim_0]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {1: [mesh_dim_1]} + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) + + # compute the memory cost of this strategy + toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) + + # compute the communication cost of this strategy + communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_0) + communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim_1) + communication_cost = communication_cost_activation_backward + communication_cost_activation_forward + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_ouput, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=toatl_memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def recompute_split_both_contract(self, mesh_dim): + name = f'RR = RS{mesh_dim} x S{mesh_dim}R' + + dim_partition_dict_for_input = {1: [mesh_dim]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {1: [mesh_dim]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {} + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + total_sharding_size = self.device_mesh.shape[mesh_dim] + compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) + + # compute the memory cost of this strategy + toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) + + # compute the communication cost of this strategy + communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim) + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_ouput, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=toatl_memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_rhs_space_only(self, mesh_dim): + name = f'RS{mesh_dim} = RR x RS{mesh_dim}' + + dim_partition_dict_for_input = {} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {1: [mesh_dim]} + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + total_sharding_size = self.device_mesh.shape[mesh_dim] + compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) + + # compute the memory cost of this strategy + toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) + + # compute the communication cost of this strategy + communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim) + communication_cost = communication_cost_activation_backward + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_ouput, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=toatl_memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + + dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) + + # compute the memory cost of this strategy + toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) + + # compute the communication cost of this strategy + communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0) + communication_cost = communication_cost_weight_backward + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_ouput, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=toatl_memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + + dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {} + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) + + # compute the memory cost of this strategy + toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) + + # compute the communication cost of this strategy + communication_cost_forward_activation = self.device_mesh.flatten_device_mesh.all_reduce_cost( + activation_memory_cost, 0) + communication_cost = communication_cost_forward_activation + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_ouput, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=toatl_memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' + + dim_partition_dict_for_input = {} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) + + # compute the memory cost of this strategy + toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) + # compute the communication cost of this strategy + communication_cost_activation_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost( + input_grad_memory_cost, 0) + communication_cost = communication_cost_activation_backward + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_ouput, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=toatl_memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + def register_strategy(self) -> StrategiesVector: + ''' + Generate every possible strategies for a linear node, and record all strategies into the strategies_vector. + + Output: + + ''' + # SS = SR x RS + self.split_lhs_space_rhs_space(0, 1) + self.split_lhs_space_rhs_space(1, 0) + + # SR = SS x SR + self.split_lhs_space_both_contract(0, 1) + self.split_lhs_space_both_contract(1, 0) + + # RS = RS x SS + self.split_rhs_space_both_contract(0, 1) + self.split_rhs_space_both_contract(1, 0) + + # RR= RS x SR + self.recompute_split_both_contract(0) + self.recompute_split_both_contract(1) + + # RS = RR x RS + self.split_rhs_space_only(0) + self.split_rhs_space_only(1) + + # S01R = S01R x RR + self.split_lhs_1st_dim_1d(0, 1) + + # RR = RS01 x S01R + self.split_lhs_2nd_dim_1d(0, 1) + + # RS01 = RR x RS01 + self.split_rhs_2nd_dim_1d(0, 1) + + return self.strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..d01a487ad6739b9eb4e23a5107fdcf0ca4a3d7f5 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py @@ -0,0 +1,179 @@ +import operator +import warnings +from copy import deepcopy +from functools import reduce +from typing import Dict, List + +import torch +from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ + ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +from .operator_handler import OperatorHandler + +__all__ = ['EmbeddingHandler'] + + +class EmbeddingHandler(OperatorHandler): + """ + An OperatorHandler which deals with the sharding strategies of Embedding operators(such as nn.embedding). + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_data = self.predecessor_node[0]._meta_data + self.weight = self.module_named_parameters['weight'] + self.output_data = self.node._meta_data + + def _generate_compute_cost(self, total_sharding_size): + input_shape = self.input_data.shape + weight_shape = self.weight.shape + input_shape_product = reduce(operator.mul, input_shape, 1) + weight_shape_product = reduce(operator.mul, weight_shape, 1) + compute_cost = input_shape_product * weight_shape_product * 2 / total_sharding_size + return compute_cost + + def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight): + ''' + Compute the memory cost per device with this specific strategy. + + Argument: + sharding_size_forward(int): The forward activation will be divided + into sharding_size_forward number partions. + sharding_size_backward_activation(int): The backward activation will + be divided into sharding_size_backward_activation number partions. + sharding_size_weight(int): The backward weight will be divided + into sharding_size_weight number partions. + + Return: + memory_cost(Tuple[float]): Memory cost per device with this + specific strategy, the first element of this tuple is forward + memory cost, and the second element of this tuple is backward + memory cost. + memory_cost_forward(float): Memory cost of forward activation per + device with this specific strategy. + memory_cost_backward_activation(float): Memory cost of backward activation + per device with this specific strategy. + ''' + # compute the memory cost of this strategy + dtype = self.input_data.dtype + numel_output = self.output_data.numel() + numel_input = self.input_data.numel() + numel_weight = self.weight.numel() + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + + # forward memory_cost + memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward + memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight + memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight + + # backward memory_cost + memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation + memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight + memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight + + # memory_cost pair + memory_cost = (memory_cost_forward, memory_cost_backward) + + return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight + + @ignore_sharding_exception + def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1): + name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}' + + dim_partition_dict_for_input = {} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {2: [mesh_dim_1]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1] + compute_cost = self._generate_compute_cost(total_sharding_size) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_1] + sharding_size_backward_activation = 1 + sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost( + sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) + + # compute the communication cost of this strategy during forward phase + communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0) + # compute the communication cost of this strategy during backward phase + communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1) + communication_cost = communication_cost_forward + communication_cost_backward + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR' + + dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) + + # compute the computation cost of this strategy + total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1] + compute_cost = self._generate_compute_cost(total_sharding_size) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_weight = 1 + memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost( + sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) + + # This strategy do not need to do all_reduce during forward phase + communication_cost_forward = 0 + # compute the communication cost of this strategy during backward phase + communication_cost_backward_activation = 0 + communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost( + memory_cost_backward_weight, 0) + communication_cost_backward = communication_cost_backward_activation + communication_cost_backward_weight + communication_cost = communication_cost_forward + communication_cost_backward + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + def register_strategy(self) -> StrategiesVector: + ''' + Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector. + ''' + # RRS = RR x SS + self.split_weight_both_dim(0, 1) + self.split_weight_both_dim(1, 0) + + # SSR = SS x RR + self.split_input_both_dim(0, 1) + self.split_input_both_dim(1, 0) + + return self.strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..c75fdbbb6dac98721b6f42e57956d2d858e6d502 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py @@ -0,0 +1,237 @@ +import operator +from functools import reduce + +import torch +from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + generate_sharding_size, ignore_sharding_exception) +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) + +from .operator_handler import OperatorHandler + +__all__ = ['LayerNormHandler'] + + +class LayerNormHandler(OperatorHandler): + """ + A OperatorHandler which deals with the sharding strategies of normalization. + + Note: To keep the math consistency, LayerNorm do not allow shards on hidden dimension. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_data = self.predecessor_node[0]._meta_data + self.weight = self.module_named_parameters['weight'] + self.bias = self.module_named_parameters['bias'] + self.output_data = self.node._meta_data + + def _generate_compute_cost(self, total_sharding_size): + ''' + Compute the computation cost per device with this specific strategy. + + Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + + Argument: + bs(int): Batch size of the input data. + channel_in(int): The channel dimension of input data. + + Return: + compute_cost(float): Computation cost per device with this specific strategy + ''' + # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + # TODO: a constant coefficient need to be added. + + norm_kernel_size = self.weight.shape + # in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization. + input_batch_shape = self.input_data.shape[:-len(norm_kernel_size)] + input_batch_product = reduce(operator.mul, input_batch_shape, 1) + norm_kernel_product = reduce(operator.mul, norm_kernel_size, 1) + forward_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size + backward_activation_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size + # To compute gradient of on norm kernel element requires input_batch_product times computation, so + # the total cost is input_batch_product * norm_kernel_product + backward_weight_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size + backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost + compute_cost = forward_compute_cost + backward_compute_cost + return compute_cost + + def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight): + ''' + Compute the memory cost per device with this specific strategy. + + Argument: + sharding_size_forward(int): The forward activation will be divided + into sharding_size_forward number partions. + sharding_size_backward_activation(int): The backward activation will + be divided into sharding_size_backward_activation number partions. + sharding_size_weight(int): The backward weight will be divided + into sharding_size_weight number partions. + + Return: + memory_cost(Tuple[float]): Memory cost per device with this + specific strategy, the first element of this tuple is forward + memory cost, and the second element of this tuple is backward + memory cost. + memory_cost_forward(float): Memory cost of forward activation per + device with this specific strategy. + memory_cost_backward_activation(float): Memory cost of backward activation + per device with this specific strategy. + ''' + # compute the memory cost of this strategy + dtype = self.input_data.dtype + numel_output = self.output_data.numel() + # this operation will not change the shape of input + numel_input = numel_output + numel_weight = self.weight.numel() + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + + # forward memory_cost + memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward + memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight + memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight + + # backward memory_cost + memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation + memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight + memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight + + # memory_cost pair + memory_cost = (memory_cost_forward, memory_cost_backward) + + return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight + + def _generate_strategy_with_dim_partition(self, dim_partition): + dim_partition_dict_for_input = dim_partition + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = dim_partition + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_input.sharding_sequence} x {sharding_spec_for_weight.sharding_sequence}' + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + total_sharding_size = generate_sharding_size(dim_partition, self.device_mesh) + # compute the computation cost of this strategy + compute_cost = self._generate_compute_cost(total_sharding_size) + + # compute the memory cost of this strategy + sharding_size_forward = generate_sharding_size(dim_partition_dict_for_input, self.device_mesh) + sharding_size_backward_activation = generate_sharding_size(dim_partition_dict_for_output, self.device_mesh) + sharding_size_weight = generate_sharding_size(dim_partition_dict_for_weight, self.device_mesh) + memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward, + sharding_size_backward_activation, + sharding_size_weight) + + total_mesh_dim_list = [] + for mesh_dim_list in dim_partition.values(): + total_mesh_dim_list.extend(mesh_dim_list) + + # This strategy do not need to do all_reduce operation for activation + communication_cost_forward_activation = 0 + communication_cost_backward_activation = 0 + if len(total_mesh_dim_list) == 1: + communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, + total_mesh_dim_list[0]) + else: + assert len(total_mesh_dim_list) == 2, f'temporally we just support 2d device mesh.' + communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost( + memory_cost_backward_weight, 0) + communication_cost = communication_cost_forward_activation + communication_cost_backward_activation + communication_cost_backward_weight + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + @ignore_sharding_exception + def split_input_batch_single_mesh_dim(self, mesh_dim_0): + batch_dimension_length = self.input_data.dim() - self.weight.dim() + dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length) + for dim_partition in dim_partition_list: + self._generate_strategy_with_dim_partition(dim_partition) + + @ignore_sharding_exception + def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1): + batch_dimension_length = self.input_data.dim() - self.weight.dim() + dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length) + for dim_partition in dim_partition_list: + self._generate_strategy_with_dim_partition(dim_partition) + + @ignore_sharding_exception + def non_split(self): + name = f'RR = RR x R' + + dim_partition_dict_for_input = {} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + total_sharding_size = 1 + # compute the computation cost of this strategy + compute_cost = self._generate_compute_cost(total_sharding_size) + + # compute the memory cost of this strategy + sharding_size_forward = 1 + sharding_size_backward_activation = 1 + sharding_size_weight = 1 + memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, + sharding_size_weight) + + # This strategy do not need to do all_reduce operation + communication_cost = 0 + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + def register_strategy(self) -> StrategiesVector: + ''' + Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector. + + Example: + norm_handler = BatchNormHandler(node, strategies_vector, + self.shape_consistency_manager) + norm_handler.register_strategy() + for strategy in norm_handler.strategies_vector: + print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}') + + Output: + RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0 + RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0 + RR = RR x R, computation_cost: 262144, memory_cost: 1048576 + RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0 + ''' + + # SR = SR x R with single mesh dim on batch dimensions + self.split_input_batch_single_mesh_dim(0) + self.split_input_batch_single_mesh_dim(1) + + # SR = SR x R with both mesh dims on batch dimensions + self.split_input_batch_both_mesh_dim(0, 1) + + # RR = RR x R + self.non_split() + + return self.strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..b120cc16b04b5594e519056a15206d0947f53038 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py @@ -0,0 +1,149 @@ +from abc import ABC, abstractmethod +from typing import Dict, List +from webbrowser import Opera + +import torch +import torch.nn as nn +from torch.fx.node import Node + +from colossalai.auto_parallel.tensor_shard.deprecated.constants import * +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.sharding_spec import ShardingSpec + +from .._utils import generate_resharding_costs, generate_sharding_spec +from ..sharding_strategy import StrategiesVector + +__all__ = ['OperatorHandler'] + + +class OperatorHandler(ABC): + ''' + The OperatorHandler is an abstract class used to generate every possible strategies for an operator node. + + Args: + node (Node): the input node in node argument list. + device_mesh (DeviceMesh): A logical view of a physical mesh. + strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector. + handle_backward (Optional[bool]): whether to consider the backward pass. The default value is True. False can be used for inference. + ''' + + def __init__(self, + node: Node, + device_mesh: DeviceMesh, + strategies_vector: StrategiesVector, + handle_backward: bool = True): + self.node = node + self.predecessor_node = list(node._input_nodes.keys()) + self.successor_node = list(node.users.keys()) + self.device_mesh = device_mesh + self.strategies_vector = strategies_vector + self.handle_backward = handle_backward + + # find the module and its parameters associated with this node + # this can be used to compute the compute/communication/sharding cost + if self.node.op == 'call_module': + module = node.graph.owning_module.get_submodule(node.target) + named_parameters = list(module.named_parameters(recurse=False)) + # convert named parameters from list to dict + named_parameters = {k: v for k, v in named_parameters} + elif self.node.op == 'call_function' and self.node.target not in NON_PARAM_FUNC_OP: + module = None + parameters = list(self.node.args)[1] + if isinstance(parameters, Node): + named_parameters = {'weight': parameters._meta_data} + else: + named_parameters = {} + else: + module = None + named_parameters = None + self.module = module + self.module_named_parameters = named_parameters + + @abstractmethod + def register_strategy(self) -> StrategiesVector: + """ + Register + """ + pass + + def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight, + sharding_spec_for_input): + ''' + Compute the memory cost per device with this specific strategy. + + Argument: + dim_partition_dict_for_output(List[int]): The key is the dimension of output to be sharded, + and the value of the key decribe which logical axis will be sharded in that dimension. + dim_partition_dict_for_weight(List[int]): The key is the dimension of weight to be sharded, + and the value of the key decribe which logical axis will be sharded in that dimension. + Return: + total_memory_cost(float): total memory cost per device with this specific strategy + activation_cost(float): the memory cost of activation per device with this specific strategy + weight_memory_cost(float): the memory cost of weight per device with this specific strategy + ''' + # compute the size of one element with specific dtype + dtype = self.input_data.dtype + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + + # compute the memory cost of activation + activation_numel = self.output_data.numel() + output_mesh_dims = [] + for sharding_dim, mesh_dims in dim_partition_dict_for_output.items(): + output_mesh_dims.extend(mesh_dims) + activation_sharding_size = 1 + for mesh_dim in output_mesh_dims: + activation_sharding_size *= self.device_mesh.shape[mesh_dim] + activation_memory_cost = activation_numel / activation_sharding_size * size_per_elem_bytes + + # compute the memory cost of weight + weight_numel = self.weight.numel() + weight_sharding_size = 1 + weight_mesh_dims = [] + for sharding_dim, mesh_dims in dim_partition_dict_for_weight.items(): + weight_mesh_dims.extend(mesh_dims) + for mesh_dim in weight_mesh_dims: + weight_sharding_size *= self.device_mesh.shape[mesh_dim] + weight_memory_cost = weight_numel / weight_sharding_size * size_per_elem_bytes + + # compute the memory cost of input grad + input_grad_numel = self.input_data.numel() + input_grad_sharding_size = 1 + input_grad_mesh_dims = [] + for sharding_dim, mesh_dims in sharding_spec_for_input.items(): + input_grad_mesh_dims.extend(mesh_dims) + for mesh_dim in input_grad_mesh_dims: + input_grad_sharding_size *= self.device_mesh.shape[mesh_dim] + input_grad_memory_cost = input_grad_numel / input_grad_sharding_size * size_per_elem_bytes + + memory_cost_forward = activation_memory_cost + weight_memory_cost + memory_cost_backward = input_grad_memory_cost + weight_memory_cost + + return (memory_cost_forward, + memory_cost_backward), activation_memory_cost, weight_memory_cost, input_grad_memory_cost + + def _generate_resharding_costs(self, sharding_specs): + # The resharding_cost of weight is counted due to sharing weight cases. + if hasattr(self.node._meta_data, 'dtype'): + dtype = self.node._meta_data.dtype + else: + assert isinstance(self.node._meta_data, + tuple), f'Only torch.Tensor, torch.fx.Node and tuple of torch.Tensor is expected' + dtype = self.node._meta_data[0].dtype + + nodes = self.predecessor_node + return generate_resharding_costs(nodes=nodes, + sharding_specs=sharding_specs, + count_backward=self.handle_backward, + dtype=dtype) + + def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: + return generate_sharding_spec(input_=input_, + device_mesh=self.device_mesh, + dim_partition_dict=dim_partition_dict) + + @abstractmethod + def _generate_compute_cost(self, *args, **kwargs): + """ + Compute the flops involved in the node. + """ + pass diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..2d39670256f56abf6181dbac28f5ca747e77b744 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py @@ -0,0 +1,89 @@ +import colorsys +import math +import warnings +from copy import deepcopy + +import torch +from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ + ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +from ..constants import INFINITY_COST +from .operator_handler import OperatorHandler + + +class ReshapeHandler(OperatorHandler): + """ + An OperatorHandler which deals with the sharding strategies of Reshape Operator, such as torch.reshape, torch.flatten, etc. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_data = self.predecessor_node[0]._meta_data + self.output_data = self.node._meta_data + + def _generate_compute_cost(self, *args, **kwargs): + return super()._generate_compute_cost(*args, **kwargs) + + @ignore_sharding_exception + def register_strategy(self): + # TODO: add strategies with more output sharding specs other than only fully replicated. + input_node = self.strategies_vector.predecessor_nodes[0] + # For reshape function, to keep the computing correctness we keep the sharding + # spec of input is fully replicated. In addition, we will keep the output in + # replica status and let the successor node choose the way to resharding the + # output node. Therefore, the different strategies of input node with same + # output sharding spec will generate same strategy for reshape function. + sharding_spec_checklist = [] + for strategy in input_node.strategies_vector: + # It looks a little bit confusing, the input of the processing node + # is the output of the input_node. + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' + if input_sharding_spec in sharding_spec_checklist: + continue + sharding_spec_checklist.append(input_sharding_spec) + dim_partition_dict_for_output = {} + if isinstance(self.output_data, tuple): + dim_partition_dict_for_output = [{} for _ in range(len(self.output_data))] + try: + if isinstance(self.output_data, tuple): + output_sharding_spec = [] + for output, dim_partition_dict in zip(self.output_data, dim_partition_dict_for_output): + output_sharding_spec.append(self._generate_sharding_spec(output, dim_partition_dict)) + else: + output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + except AssertionError as e: + warnings.warn(f'{e}') + continue + name = f'{input_sharding_spec.sharding_sequence} -> FULLY REPLICATED' + # TODO: use meta_info_prop to profile memory cost and compute cost + compute_cost = 0 + # consider node._meta_data is in type of tuple + memory_cost = 0 + + # compute the communication cost, in reshape op, the communication happens during casting the input sharding spec to fully replicating. + dim_partition_dict_for_replicate_input = {} + replicate_input_sharding_spec = self._generate_sharding_spec(self.input_data, + dim_partition_dict_for_replicate_input) + # shape consistency manager is a singleton class + shape_consistency_manager = ShapeConsistencyManager() + _, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec, + replicate_input_sharding_spec) + communication_cost = communication_cost["total"] + + # generate resharding cost + resharding_costs = self._generate_resharding_costs([input_sharding_spec]) + + # to prevent the resharding happening, set their resharding cost to inf. + resharding_costs[input_node] = [0 if cost == 0 else INFINITY_COST for cost in resharding_costs[input_node]] + sharding_strategy = ShardingStrategy(name, + output_sharding_spec, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=[input_sharding_spec]) + self.strategies_vector.append(sharding_strategy) diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..4e39fcd8e82dfdd71a7643504c8399ebcb2cd74b --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass +from abc import ABC, abstractmethod +from typing import List, Dict +from colossalai.device.device_mesh import DeviceMesh + +__all__ = ['IntermediateStrategy', 'StrategyGenerator'] + + +@dataclass +class IntermediateStrategy: + """ + IntermediateStrategy contains the subset of meta information for ShardingStrategy. It is + to store the essential information regarding the tensor sharding and leave other meta information to OperatorHandler. + + Args: + name (str): name of the sharding strategy. + dim_partition_dict (Dict[Dict]): stores the tensor to dim partition dict mapping. + all_reduce_dims (List[int]): stores the dimensions which require an all-reduce operation. + """ + name: str + dim_partition_dict: Dict[str, Dict[int, List[int]]] + all_reduce_axis: List[int] = None + + +class StrategyGenerator(ABC): + """ + StrategyGenerator is used to generate the same group of sharding strategies. + """ + + def __init__(self, device_mesh: DeviceMesh): + self.device_mesh = device_mesh + + @abstractmethod + def generate(self) -> List[IntermediateStrategy]: + """ + """ + pass + + @abstractmethod + def validate(self, *args, **kwargs) -> bool: + """ + Validate if the operands are of desired shape. + If True, means this generator can be used for the current operation. + """ + pass diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..c929d2fade98d4b388ca55bd0f2bd954e6149454 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py @@ -0,0 +1,88 @@ +import math +import operator +import warnings +from copy import deepcopy +from functools import reduce +from typing import Dict, List + +import torch +from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ + ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.deprecated.constants import \ + INFINITY_COST +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +from .operator_handler import OperatorHandler + +__all__ = ['UnaryElementwiseHandler'] + + +class UnaryElementwiseHandler(OperatorHandler): + """ + An OperatorHandler which deals with the sharding strategies of UnaryElementwiseOp. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.node.op == 'call_module': + target = self.node.target + submod = self.node.graph.owning_module.get_submodule(target) + submod_type = type(submod) + if submod_type == torch.nn.Dropout: + print(f'predecessor nodes of dropout node are {self.predecessor_node}') + input_nodes_len = 0 + for check_node in self.predecessor_node: + if isinstance(check_node._meta_data, torch.Tensor): + input_nodes_len += 1 + assert input_nodes_len == 1, f'Temporally, we just support single input element-wise op, node name is {self.node}, node args is {self.node.args}.' + self.input_data = self.predecessor_node[0]._meta_data + self.input_node = self.predecessor_node[0] + self.output_data = self.node._meta_data + + def _generate_compute_cost(self, *args, **kwargs): + return super()._generate_compute_cost(*args, **kwargs) + + @ignore_sharding_exception + def register_strategy(self): + # TODO: integrate element-wise func and module together + # create sharding strategy for element-wise function + + # For element-wise function, we keep the sharding spec of output node same as + # the input. Therefore, the different strategies of input node with same + # output sharding spec will generate same strategy for element-wise function. + + for index, strategy in enumerate(self.input_node.strategies_vector): + # It looks a little bit confusing, the input of the processing node + # is the output of the input_node. + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' + + dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict) + try: + output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict) + except AssertionError as e: + warnings.warn(f'{e}') + continue + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}_{index}' + # TODO: use meta_info_prop to profile memory cost and compute cost + compute_cost = self.output_data.numel() + memory_cost = 0 + + resharding_costs = self._generate_resharding_costs([input_sharding_spec]) + + # to prevent the resharding happening, set their resharding cost to inf. + resharding_costs[self.input_node] = [ + 0 if cost == 0 else INFINITY_COST for cost in resharding_costs[self.input_node] + ] + sharding_strategy = ShardingStrategy(name, + output_sharding_spec, + compute_cost=compute_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=[input_sharding_spec]) + self.strategies_vector.append(sharding_strategy) diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..6991e913d463a49da1f3eaa3bbf71736b238133b --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py @@ -0,0 +1,186 @@ +import operator +import warnings +from copy import deepcopy +from functools import reduce +from typing import Dict, List + +import torch + +from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception) +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +from .operator_handler import OperatorHandler + +__all__ = ['WhereHandler'] + + +class WhereHandler(OperatorHandler): + """ + An OperatorHandler which deals with the sharding strategies of torch.where. + """ + + def __init__(self, *args, **kwargs): + # TODO: x or y could be scalar + super().__init__(*args, **kwargs) + assert len(self.predecessor_node) == 3 + self.condition_data = self.predecessor_node[0]._meta_data + self.x_data = self.predecessor_node[1]._meta_data + self.y_data = self.predecessor_node[2]._meta_data + self.condition = self.predecessor_node[0] + self.x = self.predecessor_node[1] + self.y = self.predecessor_node[2] + self.output_data = self.node._meta_data + + def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: + shape = list(input_.shape) + + # padding the shape to the same length as output_data + while len(shape) < self.output_data.dim(): + shape.insert(0, 1) + shape = torch.Size(shape) + + # if the sharding happens on a size one dimension, we should record it as R. + processed_dim_partition_dict = deepcopy(dim_partition_dict) + for dim_index, _ in dim_partition_dict.items(): + if shape[dim_index] == 1: + processed_dim_partition_dict.pop(dim_index) + for dim_index, sharding_index_list in processed_dim_partition_dict.items(): + sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list] + sharding_size = reduce(operator.mul, sharding_list, 1) + assert shape[ + dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.' + sharding_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=shape, + dim_partition_dict=processed_dim_partition_dict) + + return sharding_spec + + def _generate_compute_cost(self, total_sharding_size): + lhs_matrix_shape = self.lhs_data.shape[-2:] + rhs_matrix_shape = self.rhs_data.shape[-2:] + batch_dimensions_shape = self.output_data.shape[:-2] + batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1) + compute_cost = reduce( + operator.mul, lhs_matrix_shape) * rhs_matrix_shape[0] * batch_dimensions_product * 2 / total_sharding_size + return compute_cost + + def _generate_resharding_costs(self, sharding_specs): + # The resharding_cost of weight is counted due to sharing weight cases. + dtype = self.node._meta_data.dtype + nodes = self.predecessor_node + resharding_costs = {} + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + + # shape consistency manager is a singleton class + shape_consistency_manager = ShapeConsistencyManager() + + for input_node, input_spec in zip(nodes, sharding_specs): + resharding_costs[input_node] = [] + for strategy in input_node.strategies_vector: + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' + # if the input shape is smaller than the target input, we will fill the input to the same length as target. + # Then, use the padded input sharding spec to compute the resharding cost. + if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape): + new_entire_shape = list(input_sharding_spec.entire_shape) + while len(new_entire_shape) < len(input_spec.entire_shape): + new_entire_shape.insert(0, 1) + new_entire_shape = torch.Size(new_entire_shape) + new_device_mesh = input_sharding_spec.device_mesh + new_dim_partition_dict = input_sharding_spec.dim_partition_dict + input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh, + entire_shape=new_entire_shape, + dim_partition_dict=new_dim_partition_dict) + + # compute the resharding cost + _, _, total_resharding_cost = shape_consistency_manager.shape_consistency( + input_sharding_spec, input_spec) + total_resharding_cost = total_resharding_cost['total'] + # we need multiply the size of elem dtype to get correct communication cost + resharding_cost = total_resharding_cost * size_per_elem_bytes + resharding_costs[input_node].append(resharding_cost) + + return resharding_costs + + def _convert_partition_dict_to_sharding_spec(self, dim_partition_list): + + sharding_spec_list = [] + check_duplicated_list = [] + for output_dim_partition_dict in dim_partition_list: + try: + output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict) + except AssertionError as e: + warnings.warn(f'{e}') + break + sharding_seq = output_sharding_spec.sharding_sequence + if sharding_seq not in check_duplicated_list: + check_duplicated_list.append(sharding_seq) + sharding_spec_list.append(output_sharding_spec) + + return sharding_spec_list + + def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): + # use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 in here for N-D device mesh scaliablity. + + output_dim_partition_list = [] + dim_size = self.output_data.dim() + # enumerate all the 2D sharding cases + sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size) + output_dim_partition_list.extend(sharding_list_2d) + + # enumerate all the 1D sharding cases + sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size) + output_dim_partition_list.extend(sharding_list_1d_on_dim_0) + sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size) + output_dim_partition_list.extend(sharding_list_1d_on_dim_1) + + # add empty dict for fully replicated case + output_dim_partition_list.append({}) + output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list) + + return output_sharding_spec_list + + @ignore_sharding_exception + def _register_strategy(self, output_sharding_spec): + dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict + sharding_spec_for_condition = self._generate_sharding_spec(self.condition_data, dim_partition_dict_for_input) + sharding_spec_for_x = self._generate_sharding_spec(self.x_data, dim_partition_dict_for_input) + sharding_spec_for_y = self._generate_sharding_spec(self.y_data, dim_partition_dict_for_input) + + name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_condition.sharding_sequence} x {sharding_spec_for_x.sharding_sequence} x {sharding_spec_for_y.sharding_sequence}' + dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs( + [sharding_spec_for_condition, sharding_spec_for_x, sharding_spec_for_y]) + + # compute the computation cost of this strategy + sharding_dims = [] + for mesh_dims in dim_partition_dict_for_output.values(): + for mesh_dim in mesh_dims: + sharding_dims.append(self.device_mesh.shape[mesh_dim]) + sharding_size = reduce(operator.mul, sharding_dims, 1) + memory_cost = self.output_data.numel() / sharding_size + compute_cost = memory_cost + communication_cost = 0 + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=output_sharding_spec, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_condition, sharding_spec_for_x, + sharding_spec_for_y)) + + self.strategies_vector.append(sharding_strategies) + + def register_strategy(self) -> StrategiesVector: + MESH_DIM_LIST = [0, 1] + output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1]) + for output_sharding_spec in output_sharding_specs: + self._register_strategy(output_sharding_spec) diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/options.py b/colossalai/auto_parallel/tensor_shard/deprecated/options.py new file mode 100644 index 0000000000000000000000000000000000000000..2d34f5c6447e4c06997d2e40bd12c51c5ba302fa --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/options.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass + +__all__ = ['SolverOptions'] + + +@dataclass +class SolverOptions: + """ + SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search. + """ + fast: bool = False diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/deprecated/sharding_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..d468c858e9a9803a6d0c6d3170f68eb5d1682cc6 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/sharding_strategy.py @@ -0,0 +1,91 @@ +from copy import deepcopy +from dataclasses import dataclass +from abc import ABC, abstractmethod +from enum import Enum +import operator +import torch +from functools import reduce + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec +from typing import Dict, List, Union, Tuple, Any +from torch.fx.node import Node +from .constants import * + +__all__ = ['ShardingStrategy', 'StrategiesVector'] + + +@dataclass +class ShardingStrategy: + ''' + ShardingStrategy is a structure containing sharding strategies of inputs and output of this node + and costs information using in solver. + + Argument: + name(str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'. + output_sharding_spec(ShardingSpec): ShardingSpec of the output node. + compute_cost(float): Computation cost to complete this strategy.(default to 0) + communication_cost(float): Communication cost to complete this strategy.(default to 0) + memory_cost(float): Memory cost of the output node using this strategy.(default to 0) + resharding_costs(Dict[int, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the output node argument list + with j-th strategy in its strategies_vector transforms to sharding spec wanted in this + strategy.(default to None) + input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes. + ''' + + name: str + # TODO: output of fx node,such as torch.var_mean, could be a tuple, so we cannot simply suppose it is a tensor. + output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]] + compute_cost: float = 0. + communication_cost: float = 0. + memory_cost: float = 0. + resharding_costs: Dict[Node, List[float]] = None + # sometimes the input node could be a tuple of nodes, but most of op won't accept tuple of node as input. + # Therefore, we could process them at the specific op(operator.getitem) + input_shardings: List[ShardingSpec] = None + + +class StrategiesVector(list): + ''' + Each node in fx graph will have a corresponding StrategiesVector, to store all the possible + strategies of the node. + + Argument: + node (Node): node for which the list of sharding strategies are generated. + ''' + + def __init__(self, node: Node): + super().__init__() + self.node = node + # fetch its input and output nodes + # TODO: placeholder input nodes + self.predecessor_nodes = list(node._input_nodes.keys()) + if self.node.op == 'output': + self.predecessor_nodes = list(node._input_nodes.keys())[:1] + self.successor_nodes = list(node.users.keys()) + + def check_merge(self): + merge_label = False + if self.node.op == 'call_module': + target = self.node.target + root_module = self.node.graph.owning_module + submod = root_module.get_submodule(target) + submod_type = type(submod) + # merge elementwise module node into source nodes + # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec. + if submod_type in ELEMENTWISE_MODULE_OP: + merge_label = True + + if self.node.op == 'call_function': + # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec. + if self.node.target in ELEMENTWISE_FUNC_OP: + merge_label = True + # we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case. + if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1: + merge_label = True + # we could merge reshape op, because the output sharding spec of reshape op is always fully replicated. + if self.node.target in RESHAPE_FUNC_OP: + merge_label = True + + return merge_label diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/solver.py b/colossalai/auto_parallel/tensor_shard/deprecated/solver.py new file mode 100644 index 0000000000000000000000000000000000000000..4c1d2f3bed5ab7b1e21ad9eadc3f63274ae767a6 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/solver.py @@ -0,0 +1,469 @@ +import multiprocessing +import time +import warnings +from typing import Dict + +import numpy as np +from torch.fx.graph import Graph +from torch.fx.node import Node + +from .constants import INFINITY_COST +from .cost_graph import CostGraph +from .graph_analysis import GraphAnalyser +from .strategies_constructor import StrategiesConstructor + +try: + import pulp + from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum +except: + warnings.warn(f'please install the pulp') + +__all___ = ['Solver'] + + +class Solver: + + def __init__(self, + graph: Graph, + strategies_constructor: StrategiesConstructor, + cost_graph: CostGraph, + graph_analyser: GraphAnalyser, + memory_budget: float = -1.0, + solution_numbers: int = 1, + memory_increasing_coefficient: float = 1.3): + ''' + Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph. + + Argument: + graph: The computing graph to be optimized. + strategies_constructor: It will provide all the possible strategies for each node in the computing graph. + cost_graph: A graph data structure to simplify the edge cost graph. + graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints. + memory_budget: Memory constraint for the solution. + solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget. + memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget. + ''' + self.graph = graph + self.strategies_constructor = strategies_constructor + self.cost_graph = cost_graph + self.graph_analyser = graph_analyser + self.leaf_strategies = self.strategies_constructor.leaf_strategies + self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies] + self.strategy_map = self.strategies_constructor.strategy_map + self.memory_budget = memory_budget + self.solution_numbers = solution_numbers + if self.solution_numbers > 1: + self.memory_increasing_coefficient = memory_increasing_coefficient + else: + self.memory_increasing_coefficient = 1 + self.liveness_list = self.graph_analyser.liveness_analysis() + self.node_index_dict = self._generate_node_index_dict() + # The last solution vector of auto sharding. + self.last_s_val = None + # The last objective value of the best ILP solution. + self.last_objective = None + + def _recover_merged_node_strategy(self): + ''' + During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node. + Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged + node. + ''' + for node_index, node in enumerate(self.nodes): + if node.strategies_vector.check_merge(): + # the merged node has only one input, and its strategies follow the input sharding strategy + input_strategies_vector = node.args[0].strategies_vector + input_best_strategy_index = self.last_s_val[node_index - 1] + input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec + for strategy_index, strategy in enumerate(node.strategies_vector): + if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence: + self.last_s_val[node_index] = strategy_index + break + + def _generate_node_index_dict(self) -> Dict[Node, int]: + node_index_dict = {} + for index, strategies_vector in enumerate(self.leaf_strategies): + node_index_dict[strategies_vector.node] = index + return node_index_dict + + def _prepare_data_for_solver(self): + ''' + Extract information from components for solver. + ''' + node_nums = len(self.leaf_strategies) + memory_budget = self.memory_budget + + # prepare strategies_len + strategies_len = [] + for node in self.nodes: + strategies_len.append(self.cost_graph.node_lens[node]) + strategies_len = np.array(strategies_len) + + # prepare following_nodes + following_nodes = self.cost_graph.following_dict + index_following_nodes = {} + for src, target in following_nodes.items(): + src_index = self.node_index_dict[src] + target_index = self.node_index_dict[target] + index_following_nodes[src_index] = target_index + following_nodes = index_following_nodes + for index in range(node_nums): + if index not in following_nodes: + following_nodes[index] = -1 + + # prepare edge_pairs and resharding costs + edge_pairs = [] + resharding_costs = [] + for pairs, edge_cost in self.cost_graph.edge_costs.items(): + src_node = pairs[0] + dst_node = pairs[1] + src_node_index = self.node_index_dict[src_node] + dst_node_index = self.node_index_dict[dst_node] + edge_pairs.append(src_node_index) + edge_pairs.append(dst_node_index) + + for i in range(strategies_len[src_node_index]): + for j in range(strategies_len[dst_node_index]): + resharding_costs.append(edge_cost[(i, j)]) + edge_pairs = np.array(edge_pairs) + resharding_costs = np.array(resharding_costs) + + # prepare liveness_set + liveness_set = self.liveness_list + + # omit alias_set now + alias_set = None + alias_convert_costs = None + + # prepare compute_costs, communication_costs and memory_costs + compute_costs = [] + communication_costs = [] + memory_costs = [] + extra_node_costs = self.cost_graph.extra_node_costs + for strategies_vector in self.leaf_strategies: + node = strategies_vector.node + for index, strategy in enumerate(strategies_vector): + compute_costs.append(strategy.compute_cost) + # node in extra_node_costs means it has some extra communication + # cost from node merging, so we need to add those extra communication + # cost into + if node in extra_node_costs: + origin_communication_cost = strategy.communication_cost + extra_node_cost = extra_node_costs[node][index] + communication_cost = origin_communication_cost + extra_node_cost + communication_costs.append(communication_cost) + else: + communication_costs.append(strategy.communication_cost) + # temporarily we just consider the forward memory cost + memory_cost = strategy.memory_cost + if isinstance(memory_cost, tuple): + memory_costs.append(memory_cost[0]) + else: + memory_costs.append(memory_cost) + compute_costs = np.array(compute_costs) + communication_costs = np.array(communication_costs) + memory_costs = np.array(memory_costs) + + # omit initial value for nodes + s_init_np = None + + return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np + + def _call_solver_serialized_args(self, + node_nums, + memory_budget, + strategies_len, + following_nodes, + edge_pairs, + alias_set, + liveness_set, + compute_costs, + communication_costs, + memory_costs, + resharding_costs, + alias_convert_costs, + s_init_np=None): + """ + Call the solver with serialized arguments. + """ + + tic = time.time() + + for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]: + assert isinstance(x, np.ndarray) + assert len(strategies_len) == node_nums, "strategies_len" + + def get_non_zero_index(binary_vector): + """ + Get the index of non-zero item in a vector. + """ + ct = 0 + ret = None + for i, elem in enumerate(binary_vector): + if pulp.value(elem): + ret = i + ct += 1 + + assert ct == 1 + return ret + + # 0. Unpack flatten numpy arrays + s_follow = following_nodes + + E = edge_pairs.reshape((-1, 2)) # noqa + r = [] + pt = 0 + edge_set = set() + for (i, j) in E: + prod_length = strategies_len[i] * strategies_len[j] + + if (i, j) in edge_set: + raise ValueError(f"Duplicated edges: {(i, j)}") + + edge_set.add((i, j)) + r.append(resharding_costs[pt:pt + prod_length]) + pt += prod_length + assert pt == len(resharding_costs) + + ###################### + # omit alias set now # + ###################### + + # A = alias_set.reshape((-1, 2)) # noqa + # for (i, j) in A: + # prod_length = strategies_len[i] * strategies_len[j] + # v.append(alias_convert_costs[pt:pt + prod_length]) + # pt += prod_length + # assert pt == len(alias_convert_costs) + + # L = [] # noqa + # pt = node_nums + # for i in range(node_nums): + # length = liveness_set[i] + # L.append(liveness_set[pt:pt + length]) + # pt += length + # assert pt == len(liveness_set) + v = [] + pt = 0 + + c = [] + d = [] + m = [] + pt = 0 + for i in range(node_nums): + length = strategies_len[i] + c.append(compute_costs[pt:pt + length]) + d.append(communication_costs[pt:pt + length]) + m.append(memory_costs[pt:pt + length]) + pt += length + assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}" + assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}" + assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}" + + # 1. Create variables + + ############################# + # create variables for node # + ############################# + s = [] + num_nodes = 0 + reverse_follow_backpatch = [] + for i in range(node_nums): + if s_follow[i] < 0: + if strategies_len[i] == 1: + s.append([1]) + else: + num_nodes += 1 + s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary")) + else: + if s_follow[i] < len(s): + s.append(s[s_follow[i]]) + else: + s.append(None) + reverse_follow_backpatch.append(i) + + for i in reverse_follow_backpatch: + s[i] = s[s_follow[i]] + + ############################# + # create variables for edge # + ############################# + e = [] + num_edges = 0 + for (idx, (i, j)) in enumerate(E): + if len(s[i]) == 1: + e.append(s[j]) + elif len(s[j]) == 1: + e.append(s[i]) + else: + num_edges += 1 + e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary")) + assert len(e[idx]) == len(r[idx]) + for element in s: + assert len(element) > 0 + # 2. Set initial value + ###################################### + # set a initial value for warm start # + ###################################### + if s_init_np is not None: + s_init = s_init_np.reshape((-1, 3)) + for (idx, value, fix) in s_init: + for i in range(len(s[idx])): + s[idx][i].setInitialValue(i == value) + if fix: + s[idx][i].fixValue() + + # 3. Objective + prob = LpProblem("myProblem", LpMinimize) + ################################################################### + # computing the node cost(computing cost and communication cost) # + ################################################################### + obj = 0 + for i in range(node_nums): + assert len(s[i]) == len(c[i]) + assert len(s[i]) == len(d[i]) + + obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i]) + + ############################################# + # computing the edge cost(resharding cost) # + ############################################# + for i in range(len(E)): + assert len(e[i]) == len(r[i]) + obj += lpDot(e[i], r[i]) + + prob += obj + + # 4. Constraints + # (a). specified by `cat="Binary"` + + # (b) + ################################################# + # make sure each node only choose one strategy # + ################################################# + for i in range(node_nums): + if s_follow[i] < 0: + prob += lpSum(s[i]) == 1 + + # (c) + ################################################# + # compute memory consumption with liveness set # + ################################################# + if memory_budget > 0: + for liveness_stage in liveness_set: + mem = 0 + for live_variable in liveness_stage.unique_live_vars: + node_index = self.node_index_dict[live_variable.node] + mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index]))) + prob += mem <= memory_budget + + # (d). specified by `cat="Binary"` + + for (idx, (i, j)) in enumerate(E): + if strategies_len[i] == 1 or strategies_len[j] == 1: + continue + + # (e) + prob += lpSum(e[idx]) == 1 + + # (f) + for row in range(len(s[i])): + C = len(s[j]) # noqa + prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row] + + # (g) + for col in range(len(s[j])): + R = len(s[i]) # noqa + C = len(s[j]) # noqa + prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col] + + # (h) + ###################### + # omit alias set now # + ###################### + + # alias_set = set() + # for (idx, (i, j)) in enumerate(A): + # R = len(s[i]) # noqa + # C = len(s[j]) # noqa + # if (i, j) in alias_set: + # raise ValueError(f"Duplicated edges: {(i, j)}") + + # alias_set.add((i, j)) + # alias_set.add((j, i)) + + # for row in range(len(s[i])): + # for col in range(len(s[j])): + # if v[idx][row * C + col] > 0.5: + # prob += s[i][row] + s[j][col] <= 1 + + verbose = True + + msg = verbose + time_limit = 600 + assert "COIN_CMD" in pulp.listSolvers( + onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'") + + solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count()) + # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit) + prob.solve(solver) + + status = prob.status + objective = pulp.value(prob.objective) + objective = float(objective) if objective is not None else -1.0 + if verbose: + print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" + f"Time: {time.time() - tic}") + print(f"#nodes: {num_nodes}, #edges: {num_edges}") + + if prob.status in [pulp.LpStatusInfeasible]: + raise RuntimeError("Cannot run the function under the given memory budget. " + "Please increase the memory budget.") + + # Get and check results + s_val = np.full((node_nums,), -1, dtype=np.int32) + for i in range(node_nums): + s_val[i] = get_non_zero_index(s[i]) + + e_val = np.full((len(E),), -1, dtype=np.int32) + for (idx, (i, j)) in enumerate(E): + e_val[idx] = get_non_zero_index(e[idx]) + i_spec_index = e_val[idx] // len(s[j]) + j_spec_index = e_val[idx] % len(s[j]) + assert i_spec_index == s_val[i], f"e_val[{i}][{j}]" + assert j_spec_index == s_val[j], f"e_val[{i}][{j}]" + if verbose and r[idx][e_val[idx]] > 0: + print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}") + + self.last_s_val = list(s_val) + self._recover_merged_node_strategy() + self.last_objective = objective + + if objective > INFINITY_COST: + warnings.warn("Detect unexpected behaviors in the auto-sharding pass.") + + return self.last_s_val, e_val, self.last_objective, status + + def call_solver_serialized_args(self): + """ + Call the solver with serialized arguments and handle python errors. Additionally, + we could give a serious of solutions with different memory budget. + """ + if self.solution_numbers == 1: + args = self._prepare_data_for_solver() + ret = self._call_solver_serialized_args(*args) + + return ret + + origin_memory_budget = self.memory_budget + memory_budget_list = [ + origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers) + ] + ret_list = [] + for memory_budget in memory_budget_list: + self.memory_budget = memory_budget + args = self._prepare_data_for_solver() + ret = self._call_solver_serialized_args(*args) + ret_list.append(ret) + + return ret_list diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..7bebde9d65a04a3494a5041e26107a2fe9dee121 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py @@ -0,0 +1,426 @@ +import builtins +import math +import operator +from copy import deepcopy +from typing import Dict, List + +import torch +from torch.fx import Graph, Node + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +from ._utils import generate_resharding_costs, generate_sharding_spec +from .constants import * +from .op_handler import * +from .options import SolverOptions +from .sharding_strategy import ShardingStrategy, StrategiesVector + +__all__ = ['StrategiesConstructor'] + + +class StrategiesConstructor: + """ + StrategiesConstructor is used to construct the parallelization plan for the model execution. + + Args: + graph (Graph): a Graph object used for analysis and strategy generation. + device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster. + solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching. + """ + + def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions): + self.graph = graph + assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' + self.root_module = self.graph.owning_module + self.nodes = list(graph.nodes) + self.device_mesh = device_mesh + self.leaf_strategies = [] + self.strategy_map = {} + self.solver_options = solver_options + + def remove_duplicated_strategy(self, strategies_vector): + ''' + In build_strategies_and_cost method, we may produce some duplicated strategies. + In this method, we will remove the duplicated strategies depending on the strategies name. + ''' + name_checklist = [] + remove_list = [] + for strategy in strategies_vector: + if strategy.name not in name_checklist: + name_checklist.append(strategy.name) + else: + remove_list.append(strategy) + + for strategy in remove_list: + strategies_vector.remove(strategy) + + def _is_bcast_matmul(self, node): + is_bcast_matmul = False + if node.target is torch.matmul and len(node.args) == 2: + lhs_data = node.args[0]._meta_data + rhs_data = node.args[1]._meta_data + if lhs_data.dim() >= 3 and rhs_data.dim() >= 3: + is_bcast_matmul = True + return is_bcast_matmul + + def build_strategies_and_cost(self): + for node in self.nodes: + strategies_vector = StrategiesVector(node) + input_nodes_len = 0 + for check_node in strategies_vector.predecessor_nodes: + if isinstance(check_node._meta_data, torch.Tensor): + input_nodes_len += 1 + # input_nodes_len = len(strategies_vector.predecessor_nodes) + # placeholder node + if node.op == 'placeholder': + # For placeholder nodes, if solver_options.fast is True, we just let them in + # fully replicate status, then strategies of following node will be treated equally due + # to replicate status has no resharding cost to other status. At the same time, the searching + # space is smaller than enumerating all the possible sharding spec for the placeholder node. + # Otherwise, all the possible sharding spec for the placeholder node will be enumerated. + + if self.solver_options.fast: + # create sharding strategy for placeholder + name = 'Replica Placeholder' + dim_partition_dict = {} + output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) + # TODO: use meta_info_prop to profile memory cost + memory_cost = 0 + sharding_strategy_placeholder = ShardingStrategy(name, + output_sharding_spec, + memory_cost=memory_cost) + strategies_vector.append(sharding_strategy_placeholder) + + # get_attr node + if node.op == 'get_attr': + # Same as placeholder nodes, if solver_options.fast is True, we just let them in + # fully replicate status, then strategies of following node will be treated equally due + # to replicate status has no resharding cost to other status. At the same time, the searching + # space is smaller than enumerating all the possible sharding spec for the get_attr node. + # Otherwise, all the possible sharding spec for the get_attr node will be enumerated. + if self.solver_options.fast: + # create sharding strategy for get_attr + name = 'Replica Attribute' + dim_partition_dict = {} + output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) + # TODO: use meta_info_prop to profile memory cost + memory_cost = 0 + sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost) + strategies_vector.append(sharding_strategy_attribute) + + # call_module node + if node.op == 'call_module': + + target = node.target + submod = self.root_module.get_submodule(target) + submod_type = type(submod) + + # conv module + if submod_type in CONV_MODULE_OP: + # use ConvHandler to create sharding strategies for conv module node + conv_handler = ConvHandler(node, self.device_mesh, strategies_vector) + conv_handler.register_strategy() + + # linear module + elif submod_type in LINEAR_MODULE_OP: + # use DotHandler to create sharding strategies for linear module node + dot_handler = DotHandler(node, self.device_mesh, strategies_vector) + dot_handler.register_strategy() + + # element-wise module + elif submod_type in ELEMENTWISE_MODULE_OP: + unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector) + unary_elementwise_handler.register_strategy() + + # BatchNormNd module + elif submod_type in BATCHNORM_MODULE_OP: + # create sharding strategy for element-wise module + norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector) + norm_handler.register_strategy() + # for strategy in norm_handler.strategies_vector: + # print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}') + # assert False + + # MaxPool module + elif submod_type in POOL_MODULE_OP: + # TODO: add sharding constraints on image dimension + # e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on H and W dimension + + # create sharding strategy for element-wise module + assert input_nodes_len == 1, f'Temporally, we just support single input element-wise op.' + input_node = strategies_vector.predecessor_nodes[0] + # For element-wise module, we keep the sharding spec of output node same as + # the input. Therefore, the different strategies of input node with same + # output sharding spec will generate same strategy for element-wise module. + sharding_spec_checklist = [] + for strategy in input_node.strategies_vector: + # It looks a little bit confusing, the input of the processing node + # is the output of the input_node. + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, + ShardingSpec), f'The input node should NOT be a tuple of tensor.' + if input_sharding_spec in sharding_spec_checklist: + continue + + sharding_spec_checklist.append(input_sharding_spec) + dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict) + output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) + + name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}' + + # TODO: use meta_info_prop to profile memory cost and compute cost + compute_cost = node._meta_data.numel() + memory_cost = 0 + resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, + [input_sharding_spec]) + + sharding_strategy = ShardingStrategy(name, + output_sharding_spec, + compute_cost=compute_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=[input_sharding_spec]) + strategies_vector.append(sharding_strategy) + + # embedding module + elif submod_type in EMBEDDING_MODULE_OP: + embedding_handler = EmbeddingHandler(node, self.device_mesh, strategies_vector) + embedding_handler.register_strategy() + + # layernorm module + elif submod_type in LAYERNORM_MODULE_OP: + layernorm_handler = LayerNormHandler(node, self.device_mesh, strategies_vector) + layernorm_handler.register_strategy() + # other module + else: + raise RuntimeError(f'{submod_type} module is NOT supported now.') + + # call_function node + if node.op == 'call_function': + target = node.target + # conv function + if target in CONV_FUNC_OP: + # use ConvHandler to create sharding strategies for conv node + # TODO: the operator_handler does NOT support function node processing now. + conv_handler = ConvHandler(node, self.device_mesh, strategies_vector) + conv_handler.register_strategy() + + # linear function + elif target in LINEAR_FUNC_OP and not self._is_bcast_matmul(node): + # use DotHandler to create sharding strategies for linear node + # TODO: the operator_handler does NOT support function node processing now. + linear_handler = DotHandler(node, self.device_mesh, strategies_vector) + linear_handler.register_strategy() + + # where function + elif target == torch.where: + if input_nodes_len == 1: + # both of x and y are scalar + pass + + elif input_nodes_len == 2: + # one of x or y is type of scalar + pass + + else: + # general case + where_handler = WhereHandler(node, self.device_mesh, strategies_vector) + where_handler.register_strategy() + + # reshape function + elif target in RESHAPE_FUNC_OP: + # use ReshapeHandler to create sharding strategies for rehsape node + reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector) + reshape_handler.register_strategy() + + # element-wise function + elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1): + unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector) + unary_elementwise_handler.register_strategy() + + # bcast op + elif target in BCAST_FUNC_OP: + if isinstance(node._meta_data, torch.Tensor): + bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector) + bcast_op_handler.register_strategy() + + # torch.var_mean + elif target == torch.var_mean: + dim = node.kwargs['dim'] + input_tensor_node = strategies_vector.predecessor_nodes[0] + for strategy in input_tensor_node.strategies_vector: + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, + ShardingSpec), f'The input node should NOT be a tuple of tensor.' + entire_shape_input = input_sharding_spec.entire_shape + dim_partition_dict_input = input_sharding_spec.dim_partition_dict + name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})' + if dim in dim_partition_dict_input: + # We need to make the action dimension in replicate status + dim_partition_dict_for_input = deepcopy(dim_partition_dict_input) + dim_partition_dict_for_input.pop(dim) + new_input_sharding_spec = ShardingSpec(self.device_mesh, + entire_shape_input, + dim_partition_dict=dim_partition_dict_for_input) + entire_shape_output = deepcopy(entire_shape_input) + entire_shape_output.pop(dim) + dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input) + output_sharding_spec = ShardingSpec(self.device_mesh, + entire_shape_output, + dim_partition_dict=dim_partition_dict_for_input) + # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. + compute_cost = 0 + memory_cost = 0 + resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, + [new_input_sharding_spec]) + sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec), + compute_cost=compute_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=[new_input_sharding_spec]) + + else: + entire_shape_output = deepcopy(entire_shape_input) + entire_shape_output.pop(dim) + dim_partition_dict_for_output = deepcopy(dim_partition_dict_input) + output_sharding_spec = ShardingSpec(self.device_mesh, + entire_shape_output, + dim_partion_dict=dim_partition_dict_input) + # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. + compute_cost = 0 + memory_cost = 0 + resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, + [input_sharding_spec]) + sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec), + compute_cost=compute_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=[input_sharding_spec]) + + strategies_vector.append(sharding_strategy) + + # operator.getitem + elif target == operator.getitem: + index = node.args[1] + input_tensor_node = strategies_vector.predecessor_nodes[0] + for strategy in input_tensor_node.strategies_vector: + if isinstance(strategy.output_sharding_spec, ShardingSpec): + input_sharding_spec = strategy.output_sharding_spec + else: + input_sharding_spec = strategy.output_sharding_spec[index] + assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.' + dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict) + entire_shape_output = deepcopy(input_sharding_spec.entire_shape) + output_sharding_spec = ShardingSpec(self.device_mesh, + entire_shape_output, + dim_partition_dict=dim_partition_dict_for_output) + # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. + compute_cost = 0 + memory_cost = 0 + resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, + [input_sharding_spec], + index=index) + # to prevent the resharding happening, set their resharding cost to inf. + resharding_costs[input_tensor_node] = [ + cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node] + ] + sharding_strategy = ShardingStrategy(name, + output_sharding_spec, + compute_cost=compute_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=[strategy.output_sharding_spec]) + strategies_vector.append(sharding_strategy) + + # torch.arange function + elif target == torch.arange: + name = f'FULLY REPLICATED ARANGE' + entire_shape_output = node._meta_data.shape + dim_partition_dict_for_output = {} + output_sharding_spec = ShardingSpec(self.device_mesh, + entire_shape_output, + dim_partition_dict=dim_partition_dict_for_output) + memory_cost = node._meta_data.numel() + sharding_strategy = ShardingStrategy(name, + output_sharding_spec, + compute_cost=0, + memory_cost=memory_cost) + strategies_vector.append(sharding_strategy) + + # op list to be processed to support gpt2 + elif target in (builtins.getattr, operator.le, torch.addmm): + pass + # other function + else: + raise RuntimeError(f'{target} function is NOT supported now.') + + # call_method node + if node.op == 'call_method': + method = getattr(node.args[0]._meta_data.__class__, node.target) + if method in (torch.Tensor.size,): + pass + elif method in ELEMENTWISE_METHOD_OP: + unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector) + unary_elementwise_handler.register_strategy() + + elif method in RESHAPE_METHOD_OP: + reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector) + reshape_handler.register_strategy() + # print(strategies_vector) + # if len(strategies_vector) == 0: + # print(node) + # assert False + else: + raise RuntimeError(f'{method} function is NOT supported now.') + + # output node + if node.op == 'output': + if self.solver_options.fast: + # create sharding strategy for output + name = 'Replica Output' + input_nodes = strategies_vector.predecessor_nodes + input_sharding_specs = [] + for input_node in input_nodes: + dim_partition_dict_for_input = {} + entire_shape = input_node._meta_data.shape + sharding_spec = ShardingSpec(self.device_mesh, + entire_shape, + dim_partition_dict=dim_partition_dict_for_input) + input_sharding_specs.append(sharding_spec) + + dim_partition_dict = {} + output_sharding_spec = input_sharding_specs + # TODO: use meta_info_prop to profile memory cost + memory_cost = 0 + resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, + input_sharding_specs) + + # clear the resharding cost for the output node + # TODO: we may remove this in final version + for prev_node, resharding_cost_list in resharding_costs.items(): + resharding_costs[prev_node] = [0] * len(resharding_cost_list) + + sharding_strategy_attribute = ShardingStrategy(name, + output_sharding_spec, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=tuple(input_sharding_specs)) + strategies_vector.append(sharding_strategy_attribute) + + self.remove_duplicated_strategy(strategies_vector) + setattr(node, 'strategies_vector', strategies_vector) + self.leaf_strategies.append(strategies_vector) + self.strategy_map[node] = strategies_vector + + # remove no strategy nodes + remove_list = [] + for strategies_vector in self.leaf_strategies: + if len(strategies_vector) == 0: + remove_list.append(strategies_vector.node) + for node in remove_list: + if node.strategies_vector in self.leaf_strategies: + self.leaf_strategies.remove(node.strategies_vector) + if node in self.strategy_map: + self.strategy_map.pop(node) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ba3b7cd0dace6a3adb8626df990c82502e2b5d --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py @@ -0,0 +1,31 @@ +from .addmm_handler import ADDMMFunctionHandler +from .batch_norm_handler import BatchNormModuleHandler +from .binary_elementwise_handler import BinaryElementwiseHandler +from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler +from .conv_handler import ConvFunctionHandler, ConvModuleHandler +from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler +from .experimental import PermuteHandler, ViewHandler +from .getatrr_handler import GetattrHandler +from .getitem_handler import GetItemHandler +from .layer_norm_handler import LayerNormModuleHandler +from .linear_handler import LinearFunctionHandler, LinearModuleHandler +from .matmul_handler import MatMulHandler +from .normal_pooling_handler import NormPoolingHandler +from .output_handler import OuputHandler +from .placeholder_handler import PlacehodlerHandler +from .registry import operator_registry +from .reshape_handler import ReshapeHandler +from .softmax_handler import SoftmaxHandler +from .sum_handler import SumHandler +from .tensor_constructor_handler import TensorConstructorHandler +from .unary_elementwise_handler import UnaryElementwiseHandler +from .where_handler import WhereHandler + +__all__ = [ + 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler', + 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', + 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler', + 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', + 'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler', + 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler' +] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..da0d199c5e05b37340cb4ddcfee0b52a9102fadf --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py @@ -0,0 +1,91 @@ +from typing import Dict, List, Union + +import torch + +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager + +from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy +from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator + +__all__ = ['ADDMMFunctionHandler'] + + +@operator_registry.register(torch.addmm) +@operator_registry.register(torch.Tensor.addmm) +class ADDMMFunctionHandler(NodeHandler): + """ + This is a NodeHandler class which deals with the batched matrix multiplication operation in PyTorch. + Such operations including `torch.bmm` and `torch.Tensor.bmm` require the tensor to be 3D, thus, there is + no logical-physical shape conversion in this handler. + """ + + def _infer_op_data_type(self, tensor: torch.Tensor) -> OperationDataType: + if isinstance(tensor, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + return data_type + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + + # input operand + input_data = self.node.args[1]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[1]), + type=self._infer_op_data_type(input_data), + data=input_data) + + # other operand + other_data = self.node.args[2]._meta_data + physical_other_operand = OperationData(name=str(self.node.args[2]), + type=self._infer_op_data_type(other_data), + data=other_data) + # bias physical shape + bias_logical_shape = self.node._meta_data.shape + bias_data = self.node.args[0]._meta_data + physical_bias_operand = OperationData(name=str(self.node.args[0]), + type=self._infer_op_data_type(bias_data), + data=bias_data, + logical_shape=bias_logical_shape) + + # output + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = { + "input": physical_input_operand, + "other": physical_other_operand, + "output": physical_output, + 'bias': physical_bias_operand + } + + return mapping + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append( + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm')) + return generators + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + # convert bias from its logical sharding spec to its physical sharding spec + op_data_mapping = self.get_operation_data_mapping() + + bias_op_data = op_data_mapping['bias'] + bias_physical_shape = bias_op_data.data.shape + bias_logical_shape = bias_op_data.logical_shape + bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name) + bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( + bias_sharding_spec, bias_logical_shape, bias_physical_shape) + strategy.sharding_specs[bias_op_data] = bias_sharding_spec + + if len(removed_dims) > 0: + comm_action = comm_actions_for_oprands(node=self.node, + removed_dims=removed_dims, + op_data=bias_op_data, + sharding_spec=bias_sharding_spec) + strategy.communication_actions[bias_op_data] = comm_action + + return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..6bdd15d1662147a2f0a92965947e7c4f5c92b1fc --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py @@ -0,0 +1,69 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import ModuleHandler +from .registry import operator_registry +from .strategy import BatchNormStrategyGenerator, StrategyGenerator + +__all__ = ['BatchNormModuleHandler'] + + +@operator_registry.register(torch.nn.BatchNorm1d) +@operator_registry.register(torch.nn.BatchNorm2d) +@operator_registry.register(torch.nn.BatchNorm3d) +class BatchNormModuleHandler(ModuleHandler): + """ + A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd module. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(BatchNormStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + physical_other_operand = OperationData(name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters['weight'], + logical_shape=self.named_parameters['weight'].shape) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + physical_running_mean_operand = OperationData(name="running_mean", + type=OperationDataType.BUFFER, + data=self.named_buffers['running_mean'], + logical_shape=self.named_buffers['running_mean'].shape) + + physical_running_var_operand = OperationData(name="running_var", + type=OperationDataType.BUFFER, + data=self.named_buffers['running_var'], + logical_shape=self.named_buffers['running_var'].shape) + + physical_num_batches_tracked_operand = OperationData( + name="num_batches_tracked", + type=OperationDataType.BUFFER, + data=self.named_buffers['num_batches_tracked'], + logical_shape=self.named_buffers['num_batches_tracked'].shape) + + mapping = { + "input": physical_input_operand, + "other": physical_other_operand, + "output": physical_output, + "running_mean": physical_running_mean_operand, + "running_var": physical_running_var_operand, + "num_batches_tracked": physical_num_batches_tracked_operand + } + + if self.named_parameters['bias'] is not None: + physical_bias_operand = OperationData(name="bias", + type=OperationDataType.PARAM, + data=self.named_parameters['bias']) + mapping['bias'] = physical_bias_operand + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..5b600e735f8d865da11c98c177fe7ff68574e913 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py @@ -0,0 +1,102 @@ +from typing import Dict, List, Union + +import torch +from torch.fx.node import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + OperationData, + OperationDataType, + ShardingStrategy, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager + +from ..constants import BCAST_FUNC_OP +from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator + +__all__ = ['BinaryElementwiseHandler'] + + +@operator_registry.register(BCAST_FUNC_OP) +class BinaryElementwiseHandler(NodeHandler): + """ + An BinaryBcastOpHandler is a node handler which deals with operations which have two + operands and broadcasting occurs such as torch.add. + """ + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + bcast_shape = self.node._meta_data.shape + + def _get_op_data_type(tensor): + if isinstance(tensor, torch.nn.parameter.Parameter): + return OperationDataType.PARAM + else: + return OperationDataType.ARG + + def _get_arg_value(idx): + if isinstance(self.node.args[idx], Node): + meta_data = self.node.args[idx]._meta_data + else: + # this is in fact a real data like int 1 + # but we can deem it as meta data + # as it won't affect the strategy generation + assert isinstance(self.node.args[idx], (int, float)) + meta_data = torch.Tensor([self.node.args[idx]]).to('meta') + return meta_data + + input_meta_data = _get_arg_value(0) + other_meta_data = _get_arg_value(1) + output_meta_data = self.node._meta_data + + input_op_data = OperationData(name=str(self.node.args[0]), + type=_get_op_data_type(input_meta_data), + data=input_meta_data, + logical_shape=bcast_shape) + other_op_data = OperationData(name=str(self.node.args[1]), + type=_get_op_data_type(other_meta_data), + data=other_meta_data, + logical_shape=bcast_shape) + output_op_data = OperationData(name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_meta_data, + logical_shape=bcast_shape) + + mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data} + return mapping + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(BinaryElementwiseStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + # convert bias from its logical sharding spec to its physical sharding spec + op_data_mapping = self.get_operation_data_mapping() + + for op_name, op_data in op_data_mapping.items(): + if not isinstance(op_data.data, torch.Tensor): + # remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2) + strategy.sharding_specs.pop(op_data) + else: + # convert the logical sharding spec to physical sharding spec if broadcast + # e.g. torch.rand(4, 4) + torch.rand(4) + physical_shape = op_data.data.shape + logical_shape = op_data.logical_shape + sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) + sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( + sharding_spec, logical_shape, physical_shape) + + strategy.sharding_specs[op_data] = sharding_spec + if len(removed_dims) > 0: + comm_action = comm_actions_for_oprands(node=self.node, + removed_dims=removed_dims, + op_data=op_data, + sharding_spec=sharding_spec) + strategy.communication_actions[op_data] = comm_action + + return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1d958e15ab7cca6ab0dec8b74e26c2e17be0cf --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py @@ -0,0 +1,107 @@ +from typing import Dict, List, Union + +import torch + +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager + +from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy +from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator + +__all__ = ['BMMFunctionHandler', 'AddBMMFunctionHandler'] + + +def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None): + """ + This function is a helper function which extracts the common logic for both `bmm` and `addbmm` + node handler to reduce code redundancy. + """ + # input operand + physical_input_operand = OperationData(name=str(node.args[input_idx]), + type=OperationDataType.ARG, + data=node.args[input_idx]._meta_data) + + # other operand + physical_other_operand = OperationData(name=str(node.args[other_idx]), + type=OperationDataType.ARG, + data=node.args[other_idx]._meta_data) + + # output + physical_output = OperationData(name=str(node), type=OperationDataType.OUTPUT, data=node._meta_data) + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + if bias_idx is not None: + # bias physical shape + bias_logical_shape = node._meta_data.shape + physical_bias_operand = OperationData(name=str(node.args[bias_idx]), + type=OperationDataType.ARG, + data=node.args[bias_idx]._meta_data, + logical_shape=bias_logical_shape) + mapping['bias'] = physical_bias_operand + return mapping + + +@operator_registry.register(torch.bmm) +@operator_registry.register(torch.Tensor.bmm) +class BMMFunctionHandler(NodeHandler): + """ + This is a NodeHandler class which deals with the batched matrix multiplication operation in PyTorch. + Such operations including `torch.bmm` and `torch.Tensor.bmm` require the tensor to be 3D, thus, there is + no logical-physical shape conversion in this handler. + """ + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + mapping = _get_data_mapping_for_bmm_op(node=self.node, input_idx=0, other_idx=1) + return mapping + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + +@operator_registry.register(torch.addbmm) +@operator_registry.register(torch.Tensor.addbmm) +class AddBMMFunctionHandler(NodeHandler): + """ + This is a NodeHandler class which deals with the addition + batched matrix multiplication operation in PyTorch. + Such operations including `torch.addbmm` and `torch.Tensor.addbmm` require the two matmul tensor to be 3D. However, due to the + addition, logical-physical shape conversion is required for the bias term. + + As the addbmm operation will reduce the batch dimension, the bias is maximum 2D. + """ + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + mapping = _get_data_mapping_for_bmm_op(node=self.node, input_idx=1, other_idx=2, bias_idx=0) + return mapping + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + # convert bias from its logical sharding spec to its physical sharding spec + op_data_mapping = self.get_operation_data_mapping() + + if 'bias' in op_data_mapping: + bias_op_data = op_data_mapping['bias'] + bias_physical_shape = bias_op_data.data.shape + bias_logical_shape = bias_op_data.logical_shape + bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name) + bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( + bias_sharding_spec, bias_logical_shape, bias_physical_shape) + strategy.sharding_specs[bias_op_data] = bias_sharding_spec + + if len(removed_dims) > 0: + comm_action = comm_actions_for_oprands(node=self.node, + removed_dims=removed_dims, + op_data=bias_op_data, + sharding_spec=bias_sharding_spec) + strategy.communication_actions[bias_op_data] = comm_action + + return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..0c00160effbf7c0ad514d9d33523ee05e2a8564a --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py @@ -0,0 +1,120 @@ +from typing import Dict, List + +import torch +import torch.nn.functional as F + +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from ..utils import transpose_partition_dim +from .node_handler import ModuleHandler, NodeHandler +from .registry import operator_registry +from .strategy import ConvStrategyGenerator, StrategyGenerator + +__all__ = ['ConvModuleHandler', 'ConvFunctionHandler'] + + +@operator_registry.register(torch.nn.Conv1d) +@operator_registry.register(torch.nn.Conv2d) +@operator_registry.register(torch.nn.Conv3d) +class ConvModuleHandler(ModuleHandler): + """ + A ConvModuleHandler which deals with the sharding strategies for nn.Convxd module. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + logical_shape_for_weight = list(self.named_parameters["weight"].shape) + logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[ + 1], logical_shape_for_weight[0] + physical_other_operand = OperationData(name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters['weight'], + logical_shape=torch.Size(logical_shape_for_weight)) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + if "bias" in self.named_parameters: + physical_bias_operand = OperationData(name="bias", + type=OperationDataType.PARAM, + data=self.named_parameters['bias']) + mapping['bias'] = physical_bias_operand + return mapping + + def post_process(self, strategy: ShardingStrategy): + """ + Convert the sharding spec of the weight parameter back to its original shape. + """ + for op_data, sharding_spec in strategy.input_sharding_specs.items(): + if op_data.name == "weight": + transpose_partition_dim(sharding_spec, 0, 1) + return strategy + + +@operator_registry.register(F.conv1d) +@operator_registry.register(F.conv2d) +@operator_registry.register(F.conv3d) +class ConvFunctionHandler(NodeHandler): + """ + A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + + # check if the other operand is a parameter + if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + logical_shape_for_weight = list(self.node.args[1]._meta_data.shape) + logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[ + 1], logical_shape_for_weight[0] + physical_other_operand = OperationData(name=str(self.node.args[1]), + type=data_type, + data=self.node.args[1]._meta_data, + logical_shape=torch.Size(logical_shape_for_weight)) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None: + # check if the other operand is a parameter + if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]), + type=data_type, + data=self.node.kwargs["bias"]._meta_data) + mapping['bias'] = physical_bias_operand + return mapping + + def post_process(self, strategy: ShardingStrategy): + """ + Convert the sharding spec of the weight parameter back to its original shape. + """ + for op_data, sharding_spec in strategy.input_sharding_specs.items(): + if op_data.name == str(self.node.args[1]): + transpose_partition_dim(sharding_spec, 0, 1) + return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..e154105b672de30b5675ca56147fb6b68205a469 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py @@ -0,0 +1,230 @@ +from typing import Dict, List, Union + +import torch +import torch.nn.functional as F + +from colossalai.auto_parallel.tensor_shard.utils import update_partition_dim +from colossalai.logging import get_dist_logger +from colossalai.tensor.sharding_spec import ShardingNotDivisibleError + +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from .node_handler import ModuleHandler, NodeHandler +from .registry import operator_registry +from .strategy import EmbeddingStrategyGenerator, StrategyGenerator + +__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler'] + + +def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str, + output_name: str) -> List[ShardingStrategy]: + """ + This function converts the logical sharding spec to the physical sharding spec for both the input and output + of the embedding operation. + + Args: + strategy (ShardingStrategy): the logical strategy generated by the strategy generator. + input_name (str): the name of the OperationData object for the input. + output_name (str): the name of the OperationData object for the output. + """ + # the result will be a list of strategies + sharding_strategies = [] + + # get operation data + input_op_data = strategy.get_op_data_by_name(input_name) + output_op_data = strategy.get_op_data_by_name(output_name) + input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name) + + # recover the last logical dimension to physical dimension + last_logical_output_dims = len(output_op_data.logical_shape) - 1 + last_physical_output_dims = output_op_data.data.dim() - 1 + + # get logger for debug message + logger = get_dist_logger() + + # For the input of the embedding operation, it can be multi-dimensional. The sharding spec is only generated for + # logical 1D non-matrix dimension, the logical non-matrix dimension can belong to the 0th to Nth dimension of the + # physical input shape. Thus, we enumerate to get all possible cases. + if input_sharding_spec.dim_partition_dict: + # if bool(input_sharding_spec.dim_partition_dict), it means that the + # the generated sharding strategy does shard the non-matrix dimension, + # in this case, we need to do enumeration + num_input_dims = input_op_data.data.dim() + for i in range(num_input_dims): + strategy_copy = strategy.clone() + input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) + try: + # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding + update_partition_dim(sharding_spec=input_sharding_spec, + dim_mapping={0: i}, + physical_shape=input_op_data.data.shape, + inplace=True) + + if last_logical_output_dims in output_sharding_spec.dim_partition_dict: + dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims} + else: + dim_mapping = {0: i} + + update_partition_dim(sharding_spec=output_sharding_spec, + dim_mapping=dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True) + + strategy_copy.name = f'{strategy.name}_{i}' + sharding_strategies.append(strategy_copy) + + except ShardingNotDivisibleError as e: + logger.debug( + f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}' + ) + else: + # the generated sharding strategy does not shard the non-matrix dimension, + # in this case, we don't need to do enumeration + # but instead, we still need to convert the logical shape to physical shape + strategy_copy = strategy.clone() + input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) + + # after updating, the logical shape will be replaced by the physical shape + update_partition_dim(sharding_spec=input_sharding_spec, + dim_mapping={}, + physical_shape=input_op_data.data.shape, + inplace=True) + + if last_logical_output_dims in output_sharding_spec.dim_partition_dict: + dim_mapping = {last_logical_output_dims: last_physical_output_dims} + else: + dim_mapping = {} + + update_partition_dim(sharding_spec=output_sharding_spec, + dim_mapping=dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True) + sharding_strategies.append(strategy_copy) + + return sharding_strategies + + +@operator_registry.register(torch.nn.Embedding) +class EmbeddingModuleHandler(ModuleHandler): + """ + A EmbeddingModuleHandler which deals with the sharding strategies for nn.Embedding module. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # In nn.Embedding operation, all the dimensions of input will be treated as the batch dimension, + # and then the sharding spec will be generated based on the logical 1D tensor. + # After that, the logical sharding info will be enumerated among all the physical dimensions. + # Finally, the input will be transformed back to its original shape in self.post_process + input_meta_data = self.node.args[0]._meta_data + input_logical_shape = input_meta_data.view(-1).shape + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=input_meta_data, + logical_shape=input_logical_shape) + + physical_other_operand = OperationData(name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters['weight']) + + # Same as input, in nn.Embedding operation, all the dimensions of output will be treated as + # (batch dimension, embedding dimension), and then the sharding spec will be generated based + # on the logical 2D tensor. + # After that, the logical sharding info of batch dimension will be enumerated among all the physical dimensions. + # Finally, the output will be transformed back to its original shape in self.post_process + output_meta_data = self.node._meta_data + output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape + physical_output = OperationData(name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_meta_data, + logical_shape=output_logical_shape) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + return mapping + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + """ + Convert the sharding spec from the logical shape to the physical shape. + """ + # create multiple sharding strategies for the inputs + # as input can be multi-dimensinal and the partition dim is only 2D, + # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy, + input_name=str( + self.node.args[0]), + output_name=str(self.node)) + return strategies + + +@operator_registry.register(F.embedding) +class EmbeddingFunctionHandler(NodeHandler): + """ + A EmbeddingFunctionHandler which deals with the sharding strategies for F.embedding. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # In F.embedding operation, all the dimensions of input will be treated as the batch dimension, + # and then the sharding spec will be generated based on the logical 1D tensor. + # After that, the logical sharding info will be enumerated among all the physical dimensions. + # Finally, the input will be transformed back to its original shape in self.post_process + input_meta_data = self.node.args[0]._meta_data + input_logical_shape = input_meta_data.view(-1).shape + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data, + logical_shape=input_logical_shape) + + # check if the other operand is a parameter + if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + physical_other_operand = OperationData(name=str(self.node.args[1]), + type=data_type, + data=self.node.args[1]._meta_data) + + # Same as input, in F.embedding operation, all the dimensions of output will be treated as + # (batch dimension, embedding dimension), and then the sharding spec will be generated based + # on the logical 2D tensor. + # After that, the logical sharding info of batch dimension will be enumerated among all the physical dimensions. + # Finally, the output will be transformed back to its original shape in self.post_process + output_meta_data = self.node._meta_data + output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape + physical_output = OperationData( + name=str(self.node), + type=OperationDataType.OUTPUT, + data=self.node._meta_data, + logical_shape=output_logical_shape, + ) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + return mapping + + def post_process(self, strategy: ShardingStrategy): + """ + Convert the sharding spec from the logical shape to the physical shape. + """ + # create multiple sharding strategies for the inputs + # as input can be multi-dimensinal and the partition dim is only 2D, + # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy, + input_name=str( + self.node.args[0]), + output_name=str(self.node)) + return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15f66104b15680665f8161a7e343c89992447a5c --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py @@ -0,0 +1,10 @@ +from .permute_handler import PermuteHandler +from .reshape_generator import PermuteGenerator, SplitGenerator, TransposeGenerator, ViewGenerator +from .split_handler import SplitHandler +from .transpose_handler import TransposeHandler +from .view_handler import ViewHandler + +__all__ = [ + 'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeGenerator', + 'SplitHandler', 'SplitGenerator' +] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..6d625e153f61f2445b84fc268df7bb38f488ac17 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py @@ -0,0 +1,76 @@ +from typing import Dict, List + +import torch + +from ...sharding_strategy import OperationData, OperationDataType +from ..node_handler import NodeHandler +from ..registry import operator_registry +from ..strategy import StrategyGenerator +from .reshape_generator import PermuteGenerator + +__all__ = ['PermuteHandler'] + + +@operator_registry.register(torch.Tensor.permute) +@operator_registry.register(torch.permute) +class PermuteHandler(NodeHandler): + """ + A PermuteHandler which deals with the sharding strategies for torch.permute or torch.transpose. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(PermuteGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + permute_dims = [] + if self.node.op == 'call_method': + # torch.Tensor.permute (input, *dims) + for arg in self.node.args: + if isinstance(arg, torch.fx.Node): + if isinstance(arg._meta_data, int): + permute_dims.append(arg._meta_data) + else: + assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.' + permute_dims.append(arg) + else: + # torch.permute (input, dims) + for arg in self.node.args: + if isinstance(arg, torch.fx.Node): + if isinstance(arg._meta_data, (tuple, list)): + permute_dims.extend(arg._meta_data) + else: + assert isinstance( + arg, + (tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].' + permute_dims.extend(arg) + + num_dims = self.node._meta_data.dim() + for i in range(num_dims): + # recover negative value to positive + if permute_dims[i] < 0: + permute_dims[i] += num_dims + + physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims)) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "permute_dims": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..b7248d011950dbfdb46b1e52305df871fc9898f5 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py @@ -0,0 +1,299 @@ +import copy +from typing import List + +from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ( + check_keep_sharding_status, + detect_reshape_mapping, + infer_output_dim_partition_dict, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.tensor.sharding_spec import ShardingSpec + +__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator'] + + +class ReshapeGenerator(FollowingStrategyGenerator): + """ + ReshapeGenerator is the base class for all the reshape operation. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + return super().collate_strategies() + + +class ViewGenerator(ReshapeGenerator): + """ + ViewGenerator deals with the sharding strategies of view op. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + + origin_shape = self.op_data['input'].data.shape + tgt_shape = self.op_data['tgt_shape'].data + + reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) + + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict) + + if keep_sharding_status: + dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input, + reshape_mapping_dict) + else: + dim_partition_dict_for_output = {} + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + if keep_sharding_status: + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + else: + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}' + + # add comm action for converting input to fully replicated + total_mesh_dim_list = [] + for mesh_dim_list in dim_partition_dict_for_input.values(): + total_mesh_dim_list.extend(mesh_dim_list) + # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. + if len(total_mesh_dim_list) == 1: + total_mesh_dim_list = total_mesh_dim_list[0] + # the total mesh dim list only has one element, so the shard dim has only one element as well. + shard_dim = list(dim_partition_dict_for_input.keys())[0] + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + logical_process_axis=total_mesh_dim_list, + comm_type=CommType.BEFORE, + arg_index=0) + # it will gather the input through gather_dim during forward phase. + input_comm_action.comm_spec.gather_dim = shard_dim + # it will split the input activation grad through shard_dim during backward phase. + input_comm_action.comm_spec.shard_dim = shard_dim + + elif len(total_mesh_dim_list) >= 2: + source_spec = sharding_spec_mapping["input"] + target_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=source_spec.entire_shape, + dim_partition_dict={}) + comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) + + else: + input_comm_action = None + + if input_comm_action is not None: + communication_action_mapping["input"] = input_comm_action + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list + + +class PermuteGenerator(ReshapeGenerator): + """ + PermuteGenerator deals with the sharding strategies of permute op. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + + permute_dims = self.op_data['permute_dims'].data + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + dim_partition_dict_for_output = {} + for dim_index, permute_dim in enumerate(permute_dims): + if permute_dim in dim_partition_dict_for_input: + dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim] + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list + + +class TransposeGenerator(ReshapeGenerator): + """ + TransposeGenerator deals with the sharding strategies of permute op. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + dim_partition_dict_for_output = {} + + transpose_dims = self.op_data['transpose_dims'].data + dim_0 = transpose_dims[0] + dim_1 = transpose_dims[1] + for dim, sharded_dims in dim_partition_dict_for_input.items(): + if dim == dim_0: + dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0] + elif dim == dim_1: + dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1] + else: + dim_partition_dict_for_output[dim] = sharded_dims + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list + + +class SplitGenerator(ReshapeGenerator): + """ + SplitGenerator deals with the sharding strategies of split op. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + recover_dims = None + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) + split_size, split_dim = self.op_data['split_info'].data + + if split_dim in dim_partition_dict_for_input: + recover_dims = dim_partition_dict_for_input.pop(split_dim) + + dim_partition_dict_for_output = [ + copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data)) + ] + assert len(dim_partition_dict_for_output) >= 2 + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}' + + # add comm action if the input need to be recovered to replica in the split dimension. + if recover_dims: + # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. + if len(recover_dims) == 1: + recover_dims = recover_dims[0] + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + logical_process_axis=recover_dims, + comm_type=CommType.BEFORE, + arg_index=0) + # it will gather the input through gather_dim during forward phase. + input_comm_action.comm_spec.gather_dim = split_dim + # it will split the input activation grad through split_dim during backward phase. + input_comm_action.comm_spec.shard_dim = split_dim + + elif len(recover_dims) >= 2: + # original sharding spec + source_spec = input_sharding_spec + # target sharding spec + target_spec = sharding_spec_mapping["input"] + comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) + + else: + input_comm_action = None + + if input_comm_action is not None: + communication_action_mapping["input"] = input_comm_action + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/split_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/split_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..38c5eed7d00ed22c220920acd811b2927d7dec94 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/split_handler.py @@ -0,0 +1,63 @@ +from typing import Dict, List + +import torch + +from ...sharding_strategy import OperationData, OperationDataType +from ..node_handler import NodeHandler +from ..registry import operator_registry +from ..strategy import StrategyGenerator +from .reshape_generator import SplitGenerator + +__all__ = ['SplitHandler'] + + +@operator_registry.register(torch.Tensor.split) +@operator_registry.register(torch.split) +class SplitHandler(NodeHandler): + """ + A SplitHandler which deals with the sharding strategies for torch.permute or torch.split. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(SplitGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + split_size = self.node.args[1] + if len(self.node.args) == 3: + # (input, split_size, split_dim) + split_dim = self.node.args[2] + else: + if self.node.kwargs: + split_dim = self.node.kwargs['dim'] + else: + split_dim = 0 + + num_dims = self.node.args[0]._meta_data.dim() + # recover negative value to positive + if split_dim < 0: + split_dim += num_dims + + split_info = (split_size, split_dim) + physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "split_info": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..3c7336a931677020858c9bfe1dadc5160190c49a --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py @@ -0,0 +1,65 @@ +from typing import Dict, List + +import torch + +from ...sharding_strategy import OperationData, OperationDataType +from ..node_handler import NodeHandler +from ..registry import operator_registry +from ..strategy import StrategyGenerator +from .reshape_generator import TransposeGenerator + +__all__ = ['TransposeHandler'] + + +@operator_registry.register(torch.Tensor.transpose) +@operator_registry.register(torch.transpose) +class TransposeHandler(NodeHandler): + """ + A TransposeHandler which deals with the sharding strategies for torch.permute or torch.transpose. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(TransposeGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + transpose_dims = [] + # torch.transpose (input, dim0, dim1) + for arg in self.node.args: + if isinstance(arg, torch.fx.Node): + if isinstance(arg._meta_data, int): + transpose_dims.append(arg._meta_data) + else: + transpose_dims.append(arg) + + num_dims = self.node._meta_data.dim() + for i in range(2): + # recover negative value to positive + if transpose_dims[i] < 0: + transpose_dims[i] += num_dims + + physical_shape_operand = OperationData(name='transpose_dims', + type=OperationDataType.ARG, + data=list(transpose_dims)) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "transpose_dims": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..6be634593510731fadadf612e17491048a40b7c2 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py @@ -0,0 +1,53 @@ +from typing import Dict, List + +import torch + +from ...sharding_strategy import OperationData, OperationDataType +from ..node_handler import NodeHandler +from ..registry import operator_registry +from ..strategy import StrategyGenerator +from .reshape_generator import ViewGenerator + +__all__ = ['ViewHandler'] + + +@operator_registry.register(torch.Tensor.reshape) +@operator_registry.register(torch.reshape) +@operator_registry.register(torch.Tensor.view) +class ViewHandler(NodeHandler): + """ + A ViewHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(ViewGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + target_shape = self.node._meta_data.shape + physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "tgt_shape": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getatrr_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getatrr_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..53addb873d1d1a014352058f8ec127f6bf7c4d91 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/getatrr_handler.py @@ -0,0 +1,34 @@ +from typing import Dict, List + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .strategy import GetattrGenerator, StrategyGenerator + +__all__ = ['GetattrHandler'] + + +class GetattrHandler(NodeHandler): + """ + A GetattrHandler which deals with the sharding strategies for Getattr Node. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(GetattrGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + + # There are only two possible types for get_attr node: + # 1. torch.Tensor(torch.nn.Parameters or torch.nn.Buffers) + # 2. torch.nn.Module + # temporarily, we just support first case in Tracer, so we don't have to worry about + # issue related to the node._meta_data type. + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"output": physical_output} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..25baa77666b1ef2aec75041836ec20742be406a7 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py @@ -0,0 +1,41 @@ +import operator +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import (StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator) + +__all__ = ['GetItemHandler'] + + +@operator_registry.register(operator.getitem) +class GetItemHandler(NodeHandler): + """ + A GetItemHandler which deals with the sharding strategies for operator.getitem. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + if isinstance(op_data_mapping["input"].data, torch.Tensor): + generators.append(TensorStrategyGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + else: + generators.append(TensorTupleStrategyGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + physical_other_operand = OperationData(name="index", type=OperationDataType.ARG, data=self.node.args[1]) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "index": physical_other_operand, "output": physical_output} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..132ac30daed8561b5d7349276253b7eb6d4be4c3 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py @@ -0,0 +1,44 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import ModuleHandler +from .registry import operator_registry +from .strategy import LayerNormGenerator, StrategyGenerator + +__all__ = ['LayerNormModuleHandler'] + + +@operator_registry.register(torch.nn.LayerNorm) +class LayerNormModuleHandler(ModuleHandler): + """ + A LayerNormModuleHandler which deals with the sharding strategies for nn.LayerNorm module. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(LayerNormGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + physical_other_operand = OperationData(name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters['weight'], + logical_shape=self.named_parameters['weight'].shape) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + if self.named_parameters['bias'] is not None: + physical_bias_operand = OperationData(name="bias", + type=OperationDataType.PARAM, + data=self.named_parameters['bias']) + mapping['bias'] = physical_bias_operand + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..d8e3ce6a520bcb8f96ae04363b38fe72a75fc070 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py @@ -0,0 +1,268 @@ +from typing import Dict, List, Union + +import torch +import torch.nn.functional as F + +from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim +from colossalai.logging import get_dist_logger +from colossalai.tensor.sharding_spec import ShardingNotDivisibleError + +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from .node_handler import ModuleHandler, NodeHandler +from .registry import operator_registry +from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator + +__all__ = ['LinearModuleHandler', 'LinearFunctionHandler'] + + +def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy, + weight_name: str) -> ShardingStrategy: + """ + This function is a helper function used by both module node handler and function node handler. This function will + convert the sharding spec for the transposed weight to the correct partititon spec. + + Args: + strategy (ShardingStrategy): the strategy generated by the strategy generator. + weight_name (str): the name of the OperationData object for the weight. + """ + # switch the dimensions of the transposed weight + sharding_spec = strategy.get_sharding_spec_by_name(weight_name) + op_data = strategy.get_op_data_by_name(weight_name) + assert op_data.logical_shape[0] == op_data.data.shape[1] and \ + op_data.logical_shape[1] == op_data.data.shape[0], \ + "Expected the logical shape of the linear operator's weight is equal to transposed physical shape" + dim_size = len(op_data.logical_shape) + transpose_partition_dim(sharding_spec, 0, dim_size - 1) + return strategy + + +def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy, input_name: str, + output_name: str) -> List[ShardingStrategy]: + """ + This function converts the logical sharding spec to the physical sharding spec for both the input and output of the linear operation. The input and output + should have the same sharding spec. + + Args: + strategy (ShardingStrategy): the logical strategy generated by the strategy generator. + input_name (str): the name of the OperationData object for the input. + output_name (str): the name of the OperationData object for the output. + + + """ + # the result will be a list of strategies + sharding_strategies = [] + + # get operation data + input_op_data = strategy.get_op_data_by_name(input_name) + output_op_data = strategy.get_op_data_by_name(output_name) + input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name) + + # recover the last logical dimension to physical dimension + last_logical_input_dims = len(input_op_data.logical_shape) - 1 + last_logical_output_dims = len(output_op_data.logical_shape) - 1 + last_physical_input_dims = input_op_data.data.dim() - 1 + last_physical_output_dims = output_op_data.data.dim() - 1 + + if last_logical_input_dims in input_sharding_spec.dim_partition_dict: + input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims} + else: + input_last_dim_mapping = {} + + if last_logical_output_dims in output_sharding_spec.dim_partition_dict: + output_last_dim_mapping = {last_logical_output_dims: last_physical_output_dims} + else: + output_last_dim_mapping = {} + + # get logger for debug message + logger = get_dist_logger() + + # for the input of the linear operation, it can be multi-dimensional. The sharding spec generated is only + # 2D, where the first dimension is non-matrix dimension and the last dimension is the matrix dimension. + # the logical non-matrix dimension can belong to the 0th to (N-1)th dimension of the physical input shape. + # Thus, we enumerate to get all possible cases. + if 0 in input_sharding_spec.dim_partition_dict: + # if 0 is in the dim_partition_dict, it means that the + # the generated sharding strategy does shard the non-matrix dimension, + # in this case, we need to do enumeration + num_input_dims = input_op_data.data.dim() + for i in range(num_input_dims - 1): + strategy_copy = strategy.clone() + input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) + try: + # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding + input_dim_mapping = {0: i} + input_dim_mapping.update(input_last_dim_mapping) + + update_partition_dim(sharding_spec=input_sharding_spec, + dim_mapping=input_dim_mapping, + physical_shape=input_op_data.data.shape, + inplace=True) + output_dim_mapping = {0: i} + output_dim_mapping.update(output_last_dim_mapping) + + update_partition_dim(sharding_spec=output_sharding_spec, + dim_mapping=output_dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True) + strategy_copy.name = f'{strategy.name}_{i}' + sharding_strategies.append(strategy_copy) + except ShardingNotDivisibleError as e: + logger.debug( + f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}' + ) + else: + # the generated sharding strategy does not shard the non-matrix dimension, + # in this case, we don't need to do enumeration + # but instead, we still need to convert the logical shape to physical shape + strategy_copy = strategy.clone() + input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) + + # after updating, the logical shape will be replaced by the physical shape + input_dim_mapping = {} + input_dim_mapping.update(input_last_dim_mapping) + update_partition_dim(sharding_spec=input_sharding_spec, + dim_mapping=input_dim_mapping, + physical_shape=input_op_data.data.shape, + inplace=True) + + output_dim_mapping = {} + output_dim_mapping.update(output_last_dim_mapping) + update_partition_dim(sharding_spec=output_sharding_spec, + dim_mapping=output_dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True) + sharding_strategies.append(strategy_copy) + return sharding_strategies + + +@operator_registry.register(torch.nn.Linear) +class LinearModuleHandler(ModuleHandler): + """ + A LinearModuleHandler which deals with the sharding strategies for nn.Linear module. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append( + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + input_meta_data = self.node.args[0]._meta_data + input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=input_meta_data, + logical_shape=input_logical_shape) + physical_other_operand = OperationData(name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters['weight'], + logical_shape=self.named_parameters['weight'].shape[::-1]) + output_meta_data = self.node._meta_data + output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape + physical_output = OperationData(name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_meta_data, + logical_shape=output_logical_shape) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + if 'bias' in self.named_parameters is not None: + physical_bias_operand = OperationData(name="bias", + type=OperationDataType.PARAM, + data=self.named_parameters['bias']) + mapping['bias'] = physical_bias_operand + return mapping + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + """ + Convert the sharding spec from the logical shape to the physical shape. In this function, two tasks are completed: + 1. the sharding spec is updated for the transposed weight + 2. the input and output sharding specs are updated to physical shape. + """ + # switch the dimensions of the transposed weight + strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight') + + # create multiple sharding strategies for the inputs + # as input can be multi-dimensinal and the partition dim is only 2D, + # we need to map the partition at dim 0 to one of the first few dimensions of the input + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy, + input_name=str(self.node.args[0]), + output_name=str(self.node)) + return strategies + + +@operator_registry.register(F.linear) +class LinearFunctionHandler(NodeHandler): + """ + A LinearFunctionHandler which deals with the sharding strategies for F.Linear. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append( + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + input_meta_data = self.node.args[0]._meta_data + input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data, + logical_shape=input_logical_shape) + + # check if the other operand is a parameter + if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + physical_other_operand = OperationData(name=str(self.node.args[1]), + type=data_type, + data=self.node.args[1]._meta_data, + logical_shape=self.node.args[1]._meta_data.shape[::-1]) + output_meta_data = self.node._meta_data + output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape + physical_output = OperationData( + name=str(self.node), + type=OperationDataType.OUTPUT, + data=self.node._meta_data, + logical_shape=output_logical_shape, + ) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None: + # check if the other operand is a parameter + if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]), + type=data_type, + data=self.node.kwargs["bias"]._meta_data) + mapping['bias'] = physical_bias_operand + + return mapping + + def post_process(self, strategy: ShardingStrategy): + # switch the dimensions of the transposed weight + strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, + weight_name=str(self.node.args[1])) + # create multiple sharding strategies for the inputs + # as input can be multi-dimensinal and the partition dim is only 2D, + # we need to map the partition at dim 0 to one of the first few dimensions of the input + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy, + input_name=str(self.node.args[0]), + output_name=str(self.node)) + return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..d3f9fd01d891979c0a8e7664b819176e8caac1d4 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py @@ -0,0 +1,486 @@ +import operator +from abc import ABC, abstractmethod +from copy import deepcopy +from enum import Enum +from functools import reduce +from typing import Dict, List, Union + +import torch + +from colossalai.auto_parallel.tensor_shard.utils.broadcast import ( + BroadcastType, + get_broadcast_dim_info, + get_broadcast_shape, +) +from colossalai.tensor.sharding_spec import ShardingSpecException + +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from ..utils import recover_sharding_spec_for_broadcast_shape +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import ( + BatchedMatMulStrategyGenerator, + DotProductStrategyGenerator, + LinearProjectionStrategyGenerator, + MatVecStrategyGenerator, + StrategyGenerator, +) + + +class MatMulType(Enum): + """ + The MatMulType is categorized into 4 types based on the reference of torch.matmul + in https://pytorch.org/docs/stable/generated/torch.matmul.html. + + DOT: dot product, both tensors are 1D, these two tensors need to have the same number of elements + MM: matrix-matrix product, both tensors are 2D or the 1st tensor is 1D and the 2nd tensor is 2D + MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D + BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D + """ + DOT = 0 + MM = 1 + MV = 2 + BMM = 3 + + +def get_matmul_type(input_dim: int, other_dim: int): + """ + Determine which type of matmul operation should be executed for the given tensor dimensions. + + Args: + input_dim (int): the number of dimensions for the input tenosr + other_dim (int): the number of dimensions for the other tenosr + """ + if input_dim == 1 and other_dim == 1: + matmul_type = MatMulType.DOT + elif input_dim in [1, 2] and other_dim == 2: + matmul_type = MatMulType.MM + elif input_dim == 2 and other_dim == 1: + matmul_type = MatMulType.MV + elif input_dim >= 1 and other_dim >= 1 and (input_dim > 2 or other_dim > 2): + matmul_type = MatMulType.BMM + else: + raise ValueError( + f"The input and other tensors are of {input_dim} and {other_dim} which cannot used to execute matmul operation" + ) + return matmul_type + + +class BmmTransform(ABC): + """ + BmmTransform is an abstraction of the shape conversion between logical and physical operation data + during the strategy generation. + """ + + @abstractmethod + def apply(self, shape_mapping: Dict[str, List[int]]): + pass + + @abstractmethod + def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): + pass + + +class Padder(BmmTransform): + """ + Add padding to the matrix dimensions for batched matrix multiplication. + """ + + def __init__(self) -> None: + # keep the padding dim, op_name -> padded_dim + self.padded_dim_mapping = {} + + def apply(self, shape_mapping: Dict[str, List[int]]): + mapping_copy = deepcopy(shape_mapping) + input_shape = mapping_copy['input'] + other_shape = mapping_copy['other'] + + if len(input_shape) == 1: + # if the input is a 1D tensor, 1 is prepended to its shape + # and it will be removed afterwards + input_shape.insert(0, 1) + self.padded_dim_mapping['input'] = -2 + self.padded_dim_mapping['output'] = -2 + elif len(other_shape) == 1: + # if the other is a 1D tensor, 1 is appended to its shape + # and it will be removed afterwards + other_shape = other_shape.append(1) + self.padded_dim_mapping['other'] = -1 + self.padded_dim_mapping['output'] = -1 + return mapping_copy + + def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): + input_op_data = op_data_mapping['input'] + other_op_data = op_data_mapping['other'] + + def _remove_padded_dim(key, strategy): + op_data = op_data_mapping[key] + sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) + tensor_shape = list(sharding_spec.entire_shape) + dim_partition_list = [None] * len(tensor_shape) + + # padded dim is a negative number as the padded dim must be a matrix dim + padded_dim = self.padded_dim_mapping[key] + + # compute the new dim partition + for tensor_dim, mesh_dims in sharding_spec.dim_partition_dict.items(): + dim_partition_list[tensor_dim] = mesh_dims + dim_partition_list.pop(padded_dim) + unpadded_dim_partition_list = {k: v for k, v in enumerate(dim_partition_list) if v is not None} + + # compute unpadded tensor shape + tensor_shape.pop(padded_dim) + + assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}' + + # update sharding spec + sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list) + + # enumerate all sharding strategies + strategies = [] + try: + strategy_copy = strategy.clone() + + # only one of input and other will be padded + if 'input' in self.padded_dim_mapping: + _remove_padded_dim('input', strategy_copy) + _remove_padded_dim('output', strategy_copy) + elif 'other' in self.padded_dim_mapping: + _remove_padded_dim('other', strategy_copy) + _remove_padded_dim('output', strategy_copy) + + strategies.append(strategy_copy) + except ShardingSpecException as e: + pass + return strategies + + +class Broadcaster(BmmTransform): + """ + Broadcast the non-matrix dimensions for batched matrix multiplication. + """ + + def __init__(self) -> None: + self.broadcast_dim_info = {} + + def apply(self, shape_mapping: Dict[str, List[int]]): + mapping_copy = shape_mapping.copy() + + # get shapes + input_shape = mapping_copy['input'] + other_shape = mapping_copy['other'] + + # sanity check + assert len(input_shape) > 1 and len(other_shape) > 1 + + # broadcast the batch dim and record + bcast_non_matrix_dims = get_broadcast_shape(input_shape[:-2], other_shape[:-2]) + + # store the broadcast dim info + input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2]) + other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2]) + self.broadcast_dim_info['input'] = input_broadcast_dim_info + self.broadcast_dim_info['other'] = other_broadcast_dim_info + + # create the full logical shape + input_shape = bcast_non_matrix_dims + input_shape[-2:] + other_shape = bcast_non_matrix_dims + other_shape[-2:] + assert len(input_shape) == len(other_shape) + + mapping_copy['input'] = input_shape + mapping_copy['other'] = other_shape + + return mapping_copy + + def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): + # remove sharding on the broadcast dim + def _remove_sharding_on_broadcast_dim(key, strategy): + op_data = op_data_mapping[key] + sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) + tensor_shape = list(sharding_spec.entire_shape) + + for dim_idx, broadcast_type in self.broadcast_dim_info[key].items(): + if broadcast_type == BroadcastType.MULTIPLE: + # if the dim is originally 1 and multiplied during broadcast + # we set its sharding to R + # e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8] + # the dim 0 of [1, 2, 4] is multiplied to 4 + tensor_shape[dim_idx] = 1 + elif broadcast_type == BroadcastType.PADDDING: + # if the dim is padded + # we remove its sharding + tensor_shape[dim_idx] = None + + tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None] + + physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( + logical_sharding_spec=sharding_spec, + logical_shape=sharding_spec.entire_shape, + physical_shape=tensor_shape_before_broadcast) + strategy.sharding_specs[op_data] = physical_sharding_spec + + # enumerate all sharding strategies + strategies = [] + try: + strategy_copy = strategy.clone() + _remove_sharding_on_broadcast_dim('input', strategy_copy) + _remove_sharding_on_broadcast_dim('other', strategy_copy) + strategies.append(strategy_copy) + except ShardingSpecException as e: + pass + return strategies + + +class Viewer(BmmTransform): + """ + Change the shape of the tensor from N-D to 3D + """ + + def __init__(self) -> None: + self.batch_dims_before_view = None + + def apply(self, shape_mapping: Dict[str, List[int]]): + mapping_copy = shape_mapping.copy() + self.batch_dims_before_view = list(mapping_copy['input'][:-2]) + + # get shapes + input_shape = shape_mapping['input'] + other_shape = shape_mapping['other'] + + # view to 3d tensor + assert len(input_shape) >= 3 and len(other_shape) >= 3 + input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:] + other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:] + output_shape = input_shape[:2] + other_shape[2:] + mapping_copy['input'] = input_shape + mapping_copy['other'] = other_shape + mapping_copy['output'] = output_shape + return mapping_copy + + def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): + # get operation data + def _update_sharding_spec(key, strategy, physical_batch_dim): + """ + Map the logical batch dim to the physical batch dim + """ + op_data = op_data_mapping[key] + sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) + dim_partition_dict = sharding_spec.dim_partition_dict + entire_shape = sharding_spec.entire_shape + + # upddate the dimension index for the matrix dimensions + if 2 in dim_partition_dict: + dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2) + if 1 in dim_partition_dict: + dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1) + + # map the logical batch dim to phyiscal batch dim + if 0 in dim_partition_dict: + batch_dim_shard = dim_partition_dict.pop(0) + dim_partition_dict[physical_batch_dim] = batch_dim_shard + + # the new shape will be the batch dims + the last 2 matrix dims + shape_before_view = self.batch_dims_before_view + list(entire_shape[-2:]) + sharding_spec.__init__(sharding_spec.device_mesh, shape_before_view, dim_partition_dict) + + num_batch_dim_before_view = len(self.batch_dims_before_view) + + # enumerate all sharding strategies + strategies = [] + for i in range(num_batch_dim_before_view): + # create a new strategy + strategy_copy = strategy.clone() + try: + _update_sharding_spec('input', strategy_copy, i) + _update_sharding_spec('other', strategy_copy, i) + _update_sharding_spec('output', strategy_copy, i) + strategies.append(strategy_copy) + except ShardingSpecException as e: + continue + return strategies + + +def _get_bmm_logical_shape(input_shape, other_shape, transforms): + """ + Compute the logical shapes for BMM operation. BMM has a general representation + [b, i, k] = [b, i, j] x [b, j, k] + + The dimension b is called non-matrix (batch) dimension and the remaining dimensions are called matrix dimensions + The logical shape for the bmm operands will undergo three stages + 1. append/prepend the 1 to the 1D tensor if there is any + 2. broadcast the non-matrix dimensions + 3. reshape to 3 dimensions + + """ + shape_mapping = {'input': input_shape, 'other': other_shape} + + for transform in transforms: + shape_mapping = transform.apply(shape_mapping) + + input_shape = shape_mapping.get('input', None) + other_shape = shape_mapping.get('other', None) + output_shape = shape_mapping.get('output', None) + + return input_shape, other_shape, output_shape + + +@operator_registry.register(torch.matmul) +@operator_registry.register(torch.Tensor.matmul) +class MatMulHandler(NodeHandler): + """ + The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation. + According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on + the operands. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + # check which type of operation this matmul will call + self.input_meta_data = self.node.args[0]._meta_data + self.other_meta_data = self.node.args[1]._meta_data + self.output_meta_data = self.node._meta_data + + input_dim = self.input_meta_data.dim() + other_dim = self.other_meta_data.dim() + self.matmul_type = get_matmul_type(input_dim, other_dim) + + if self.matmul_type == MatMulType.BMM: + # bmm operation can possibly involve padding, broadcasting and view + # these transforms will be used to create logical shape and + # recover physical sharding spec + self.transforms = [Padder(), Broadcaster(), Viewer()] + else: + self.transforms = None + + def get_strategy_generator(self) -> List[StrategyGenerator]: + generators = [] + op_data_mapping = self.get_operation_data_mapping() + if self.matmul_type == MatMulType.BMM: + generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)) + elif self.matmul_type == MatMulType.DOT: + generators.append(DotProductStrategyGenerator(op_data_mapping, self.device_mesh)) + elif self.matmul_type == MatMulType.MV: + generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh)) + elif self.matmul_type == MatMulType.MM: + generators.append( + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + logical_shape_func = { + MatMulType.DOT: self._get_logical_shape_for_dot, + MatMulType.MM: self._get_logical_shape_for_mm, + MatMulType.MV: self._get_logical_shape_for_mv, + MatMulType.BMM: self._get_logical_shape_for_bmm + } + logical_shapes = logical_shape_func[self.matmul_type]() + op_data_mapping = self._get_op_data_mapping(*logical_shapes) + return op_data_mapping + + def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_logical_shape): + # convert list to torch.Size + if input_logical_shape: + input_logical_shape = torch.Size(input_logical_shape) + + if other_logical_shape: + other_logical_shape = torch.Size(other_logical_shape) + + if output_logical_shape: + output_logical_shape = torch.Size(output_logical_shape) + + # create op data + input_op_data = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.input_meta_data, + logical_shape=input_logical_shape) + other_op_data = OperationData(name=str(self.node.args[1]), + type=OperationDataType.ARG, + data=self.other_meta_data, + logical_shape=other_logical_shape) + output_op_data = OperationData(name=str(self.node), + type=OperationDataType.OUTPUT, + data=self.output_meta_data, + logical_shape=output_logical_shape) + + mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data} + return mapping + + def _get_logical_shape_for_dot(self): + """ + The operands for the dot operation have the same logical shape as the physical shape + """ + return None, None, None + + def _get_logical_shape_for_mm(self): + """ + We need to handle the input tensor for a matrix-matrix multiplcation as the input + tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape + (e.g. [4] -> [1, 4]). + """ + if self.input_meta_data.dim() == 1: + input_logical_shape = [1] + list(self.input_meta_data.shape) + input_logical_shape = torch.Size(input_logical_shape) + else: + input_logical_shape = None + return input_logical_shape, None, None + + def _get_logical_shape_for_mv(self): + """ + No broadcasting or dim insertion occurs for matrix-vector operation. + """ + return None, None, None + + def _get_logical_shape_for_bmm(self): + input_physical_shape = list(self.input_meta_data.shape) + other_physical_shape = list(self.other_meta_data.shape) + return _get_bmm_logical_shape(input_physical_shape, other_physical_shape, self.transforms) + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + if self.matmul_type in [MatMulType.DOT, MatMulType.MV]: + return strategy + elif self.matmul_type == MatMulType.MM: + if self.input_meta_data.dim() == 1: + # if a 1 is prepended to the input shape (this occurs when input is a 1D tensor) + # we need to remove that dim + input_sharding_spec = strategy.get_sharding_spec_by_name(str(self.node.args[0])) + input_physical_shape = self.node.args[0]._meta_data.shape + dim_partition_dict = input_sharding_spec.dim_partition_dict + + # remove the partitioning in the dim 0 + if 0 in dim_partition_dict: + dim_partition_dict.pop(0, None) + + # move the partitioning in dim 1 to dim 0 + if -1 in dim_partition_dict: + shard = dim_partition_dict.pop(-1) + dim_partition_dict[0] = shard + if 1 in dim_partition_dict: + shard = dim_partition_dict.pop(1) + dim_partition_dict[0] = shard + + # re-init the sharding spec + input_sharding_spec.__init__(input_sharding_spec.device_mesh, + entire_shape=input_physical_shape, + dim_partition_dict=dim_partition_dict) + return strategy + else: + return strategy + elif self.matmul_type == MatMulType.BMM: + op_data_mapping = self.get_operation_data_mapping() + + strategies = [strategy] + # recover the physical sharding spec + for transform in self.transforms[::-1]: + recovered_stragies = [] + for strategy_ in strategies: + output = transform.recover(op_data_mapping, strategy_) + if isinstance(output, ShardingStrategy): + recovered_stragies.append(output) + elif isinstance(output, (list, tuple)): + recovered_stragies.extend(output) + else: + raise TypeError( + f"Found unexpected output type {type(output)} from the recover method of BmmTransform") + strategies = recovered_stragies + return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..27957ca63126c57adb6385d335bb54eff563b43d --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -0,0 +1,223 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple, Union + +import torch +from torch.fx.node import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingSpec, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import check_sharding_spec_validity +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import ShapeConsistencyManager + +from .strategy import StrategyGenerator + + +class NodeHandler(ABC): + ''' + The NodeHandler is an abstract class used to generate every possible strategies for an operator node. + + Args: + node (Node): the input node in node argument list. + device_mesh (DeviceMesh): A logical view of a physical mesh. + strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector. + ''' + + def __init__( + self, + node: Node, + device_mesh: DeviceMesh, + strategies_vector: StrategiesVector, + ) -> None: + self.node = node + self.predecessor_node = list(node._input_nodes.keys()) + self.successor_node = list(node.users.keys()) + self.device_mesh = device_mesh + self.strategies_vector = strategies_vector + + def update_resharding_cost(self, strategy: ShardingStrategy) -> None: + """ + Compute the resharding costs and save the costs in the ShardingStrategy object. + """ + # TODO: test this function when other handlers are ready + resharding_costs = {} + shape_consistency_manager = ShapeConsistencyManager() + + for node in self.predecessor_node: + node_name = str(node) + # get the current sharding spec generated by this node handler + + # we will not compute the resharding costs for the node not counted in the strategy. + # And the node with tuple or list output need to be handled below. + node_in_strategy = [op_data.name for op_data in strategy.sharding_specs.keys()] + if str(node) not in node_in_strategy: + continue + + op_data = strategy.get_op_data_by_name(node_name) + current_sharding_spec = strategy.sharding_specs[op_data] + # get the sharding specs for this node generated + # in its own node handler + assert hasattr(node, 'strategies_vector'), \ + f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.' + prev_strategy_vector = node.strategies_vector + prev_sharding_specs = [ + prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector + ] + + # create data structrure to store costs + if node not in resharding_costs: + resharding_costs[node] = [] + + def _compute_resharding_cost( + prev_sharding_spec: Union[ShardingSpec, + List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec, + List[ShardingSpec]], + data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem: + """ + This is a helper function to compute the resharding cost for a specific strategy of a node. + """ + if prev_sharding_spec is None: + return TrainCycleItem(fwd=0, bwd=0, total=0) + elif isinstance(prev_sharding_spec, ShardingSpec): + if isinstance(data, torch.nn.parameter.Parameter): + # we won't compute the resharding cost for the parameters, + # since the parameters will be sharded before runtime and + # not converted during runtime. + return TrainCycleItem(fwd=0, bwd=0, total=0) + elif isinstance(data, torch.Tensor): + dtype = data.dtype + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + _, _, consistency_cost = shape_consistency_manager.shape_consistency( + prev_sharding_spec, current_sharding_spec) + + resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes, + bwd=consistency_cost["backward"] * size_per_elem_bytes, + total=consistency_cost["total"] * size_per_elem_bytes) + return resharding_cost + else: + # This raise is used to check if we have missed any type of data. + # It could be merged into Parameter branch, which means we won't handle + # non-tensor arguments. + raise ValueError(f'Unsupported data type {type(data)}') + else: + assert isinstance(prev_sharding_spec, (tuple, list)), \ + f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \ + or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}' + + fwd_cost = 0 + bwd_cost = 0 + total_cost = 0 + for index, (prev_sharding_spec_item, + current_sharding_spec_item) in enumerate(zip(prev_sharding_spec, + current_sharding_spec)): + item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item, + data[index]) + fwd_cost += item_cost.fwd + bwd_cost += item_cost.bwd + total_cost += item_cost.total + resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=total_cost) + return resharding_cost + + # for each sharding spec generated by the predecessor's node handler + # compute the resharding cost to switch to the sharding spec generated + # by the current node handler + for prev_sharding_spec in prev_sharding_specs: + resharding_cost = _compute_resharding_cost(prev_sharding_spec, current_sharding_spec, op_data.data) + resharding_costs[node].append(resharding_cost) + strategy.resharding_costs = resharding_costs + return strategy + + def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector: + """ + Register different sharding strategies for the current node. + """ + strategy_generators = self.get_strategy_generator() + for generator in strategy_generators: + strategies = generator.generate() + + # postprocess a strategy + # postprocess can produce one strategy or multiple strategies + post_processed_strategies_map = map(self.post_process, strategies) + post_processed_strategies = [] + + for strategy in post_processed_strategies_map: + if isinstance(strategy, (list, tuple)): + post_processed_strategies.extend(strategy) + else: + post_processed_strategies.append(strategy) + + # compute the resharding costs based on the previous node + # strategies if specified + if compute_resharding_cost: + updated_strategies = map(self.update_resharding_cost, post_processed_strategies) + post_processed_strategies = list(updated_strategies) + + self.strategies_vector.extend(post_processed_strategies) + + # validating the correctness of the sharding strategy + for strategy in self.strategies_vector: + for op_data, sharding_spec in strategy.sharding_specs.items(): + if op_data.data is not None and isinstance(op_data.data, torch.Tensor): + check_sharding_spec_validity(sharding_spec, op_data.data) + + return self.strategies_vector + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + # tranform the strategy generated + # e.g. to process the sharding strategy for the transposed weights + return strategy + + @abstractmethod + def get_strategy_generator(self) -> List[StrategyGenerator]: + """ + Define which generators should be used by this NodeHandler object. + """ + pass + + @abstractmethod + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + """ + Returns the mapping between the logical operation data to its physical data. + A logical operation data is a data associated with an operation, which can be input and output. It is + defined by the strategy generator, for example, a matrix multiplication operation has two operands "input" + and "other" and one result "output". For a nn.Linear module, the physical operand for "input" is + the module input, the physical operand for "other" is the module weight, and the physical result for "output" + is the module output. + Note that the operand name is specified by the StrategyGenerator object. + + For example: + + # for a linear layer + mapping = { + "input": Operand(name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data), + "other": Operand(name="weight", type=OperationDataType.PARAM, data=self.named_parameters['weight']), + "bias": Operand(name="bias", type=OperationDataType.PARAM, data=self.named_parameters['bias']), + "output": Operand(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data), + } + """ + pass + + +class ModuleHandler(NodeHandler): + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + # set attributes to access module parameters for convenience + assert self.node.graph.owning_module is not None, \ + f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.' + module = self.node.graph.owning_module.get_submodule(self.node.target) + named_parameters = list(module.named_parameters(recurse=False)) + named_buffers = list(module.named_buffers(recurse=False)) + # convert named parameters from list to dict + named_parameters = {k: v for k, v in named_parameters} + named_buffers = {k: v for k, v in named_buffers} + self.module = module + self.named_parameters = named_parameters + self.named_buffers = named_buffers diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..1509c05a351263fc61dd66d824eee85fb8362640 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py @@ -0,0 +1,41 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import ModuleHandler +from .registry import operator_registry +from .strategy import NormalPoolStrategyGenerator, StrategyGenerator + +__all__ = ['NormPoolingHandler'] + + +@operator_registry.register(torch.nn.MaxPool1d) +@operator_registry.register(torch.nn.MaxPool2d) +@operator_registry.register(torch.nn.MaxPool1d) +@operator_registry.register(torch.nn.AvgPool1d) +@operator_registry.register(torch.nn.AvgPool2d) +@operator_registry.register(torch.nn.AvgPool3d) +class NormPoolingHandler(ModuleHandler): + """ + A NormPoolingHandler which deals with the sharding strategies for nn.MaxPoolxd module. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(NormalPoolStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + physical_weight_operand = OperationData(name="kernel", type=OperationDataType.ARG, data=self.module.kernel_size) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "other": physical_weight_operand, "output": physical_output} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..d2edfa83c37dbf76062f252ab18603975894aa26 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py @@ -0,0 +1,52 @@ +from typing import Dict, List + +import torch + +from colossalai.device.device_mesh import DeviceMesh + +from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector +from .node_handler import NodeHandler +from .strategy import OutputGenerator, StrategyGenerator + +__all__ = ['OuputHandler'] + + +class OuputHandler(NodeHandler): + """ + A OuputHandler which deals with the sharding strategies for Output Node. + """ + + def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, + output_option: str) -> None: + super().__init__(node, device_mesh, strategies_vector) + self.output_option = output_option + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node, self.output_option)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + mapping = {} + output_meta_data = [] + for index, input_node in enumerate(self.predecessor_node): + input_meta_data = input_node._meta_data + physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data) + name_key = f'input_{index}' + mapping[name_key] = physical_inputs + output_meta_data.append(input_meta_data) + + assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.' + if len(output_meta_data) == 1: + output_meta_data = output_meta_data[0] + else: + output_meta_data = tuple(output_meta_data) + + self.node._meta_data = output_meta_data + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping["output"] = physical_output + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..c72a5d3bfa9bea4947b0d38b96a989938f37e0b9 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py @@ -0,0 +1,38 @@ +from typing import Dict, List + +from torch.fx.node import Node + +from colossalai.device.device_mesh import DeviceMesh + +from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector +from .node_handler import NodeHandler +from .strategy import PlaceholderGenerator, StrategyGenerator + +__all__ = ['PlacehodlerHandler'] + + +class PlacehodlerHandler(NodeHandler): + """ + A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node. + """ + + def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, + placeholder_option: str) -> None: + super().__init__(node, device_mesh, strategies_vector) + self.placeholder_option = placeholder_option + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append( + PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"output": physical_output} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..8e06cec4f463a8600b2abe1a7f6713ec2ffb2931 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py @@ -0,0 +1,30 @@ +class Registry: + # TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here + + def __init__(self, name): + self.name = name + self.store = {} + + def register(self, source): + + def wrapper(func): + if isinstance(source, (list, tuple)): + # support register a list of items for this func + for element in source: + self.store[element] = func + else: + self.store[source] = func + return func + + return wrapper + + def get(self, source): + assert source in self.store, f'{source} not found in the {self.name} registry' + target = self.store[source] + return target + + def has(self, source): + return source in self.store + + +operator_registry = Registry('operator') diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..b463487165cbaf44ced0111d7cbeb3257982002a --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py @@ -0,0 +1,71 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import ReshapeGenerator, StrategyGenerator + +__all__ = ['ReshapeHandler'] + + +@operator_registry.register(torch.flatten) +@operator_registry.register(torch.Tensor.unsqueeze) +@operator_registry.register(torch.nn.AdaptiveAvgPool2d) +class ReshapeHandler(NodeHandler): + """ + A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def infer_logical_shape(self, data): + """ + This function is used to infer logical shape for operands. + + Notes: This function is only used for the operands whose data are not only in type of tensor, + such as tuple of tensor. + """ + if isinstance(data, torch.Tensor): + return data.shape + else: + assert isinstance(data, tuple), "input_data should be a tuple of tensor or a tensor." + logical_shape = [] + for tensor in data: + assert isinstance(tensor, torch.Tensor), "input_data should be a tuple of tensor or a tensor." + logical_shape.append(tensor.shape) + logical_shape = tuple(logical_shape) + return logical_shape + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + input_logical_shape = self.infer_logical_shape(input_data) + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=data_type, + data=input_data, + logical_shape=input_logical_shape) + + output_data = self.node._meta_data + output_logical_shape = self.infer_logical_shape(output_data) + physical_output = OperationData(name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_data, + logical_shape=output_logical_shape) + + mapping = {"input": physical_input_operand, "output": physical_output} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..743a1f90eaafa869b3a62882648cbde53f9e3166 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py @@ -0,0 +1,55 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import SoftmaxGenerator, StrategyGenerator + +__all__ = ['SoftmaxHandler'] + + +@operator_registry.register(torch.nn.Softmax) +@operator_registry.register(torch.nn.functional.softmax) +class SoftmaxHandler(NodeHandler): + """ + A SoftmaxHandler which deals with the sharding strategies for + torch.nn.Softmax or torch.nn.functional.softmax. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(SoftmaxGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + softmax_dim = self.node.kwargs['dim'] + + num_dims = self.node.args[0]._meta_data.dim() + # recover negative value to positive + if softmax_dim < 0: + softmax_dim += num_dims + + physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "softmax_dim": physical_dim_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d25475f9c57f756148a51baa3060680343af5c3 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py @@ -0,0 +1,32 @@ +from .batch_norm_generator import BatchNormStrategyGenerator +from .binary_elementwise_generator import BinaryElementwiseStrategyGenerator +from .conv_strategy_generator import ConvStrategyGenerator +from .embedding_generator import EmbeddingStrategyGenerator +from .getattr_generator import GetattrGenerator +from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator +from .layer_norm_generator import LayerNormGenerator +from .matmul_strategy_generator import ( + BatchedMatMulStrategyGenerator, + DotProductStrategyGenerator, + LinearProjectionStrategyGenerator, + MatVecStrategyGenerator, +) +from .normal_pooling_generator import NormalPoolStrategyGenerator +from .output_generator import OutputGenerator +from .placeholder_generator import PlaceholderGenerator +from .reshape_generator import ReshapeGenerator +from .softmax_generator import SoftmaxGenerator +from .strategy_generator import StrategyGenerator +from .sum_generator import SumGenerator +from .tensor_constructor_generator import TensorConstructorGenerator +from .unary_elementwise_generator import UnaryElementwiseGenerator +from .where_generator import WhereGenerator + +__all__ = [ + 'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator', + 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator', + 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator', + 'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', + 'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator', + 'TensorConstructorGenerator', 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator' +] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..1f3812429fc274064163f5859d0fedb04f8115fb --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py @@ -0,0 +1,350 @@ +import copy +import operator +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import StrategyGenerator + +__all__ = ['BatchNormStrategyGenerator'] + + +class BatchNormStrategyGenerator(StrategyGenerator): + """ + A StrategyGenerator which deals with the sharding strategies of batch normalization. + + To keep the math consistency, there are two way to do BatchNorm if the input + shards on batch dimension: + 1. We gather the input partitions through batch dimension, then do the normal BatchNorm. + 2. We do the SyncBatchNorm on the each input partition seperately, the SyncBN op will help + us to keep the computing correctness. + In this generator, both methods will be considered. + """ + + def validate(self) -> bool: + ''' + In sanity check, we need make sure the input data having correct dimension size. + For BatchNorm1d, the dim of input data should be 3([N, C, L]). + For BatchNorm2d, the dim of input data should be 4([N, C, H, W]). + For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]). + ''' + input_op_data = self.op_data['input'] + assert input_op_data.data.dim() in ( + 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' + + def update_compute_cost(self, strategy: ShardingStrategy): + ''' + Compute the computation cost per device with this specific strategy. + + Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + ''' + # TODO: a constant coefficient need to be added. + # 1D: (L) * N * Cin + # 2D: (H * W) * N * Cin + # 3D: (H * W * D) * N * Cin + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + if self.has_bias: + # bias add is an element wise operation, so the cost is equal to product of output shape. + bias_compute_cost = reduce(operator.mul, sharded_output_shape) + input_product = reduce(operator.mul, sharded_input_shape, 1) + forward_compute_cost = input_product + backward_activation_compute_cost = input_product + backward_weight_compute_cost = input_product + backward_compute_cost = backward_weight_compute_cost + backward_activation_compute_cost + if self.has_bias: + forward_compute_cost += bias_compute_cost + backward_compute_cost += bias_compute_cost + total_compute_cost = forward_compute_cost + backward_compute_cost + compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'other': self._compute_size_in_bytes(strategy, "other"), + 'output': self._compute_size_in_bytes(strategy, "output"), + 'running_mean': self._compute_size_in_bytes(strategy, "running_mean"), + 'running_var': self._compute_size_in_bytes(strategy, "running_var"), + } + + if self.has_bias: + bias_size = self._compute_size_in_bytes(strategy, "bias") + forward_size_mapping['bias'] = bias_size + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + other + bias + output + fwd_activation_cost = sum( + [v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_buffer_cost = sum([v for k, v in forward_size_mapping.items() if self.is_buffer(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost, buffer=fwd_buffer_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + other_grad + bias_grad + bwd_activation_cost = sum( + [v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost, + buffer=fwd_buffer_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def split_input_channel(self, mesh_dim_0): + name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' + dim_partition_dict_mapping = { + "input": { + 1: [mesh_dim_0] + }, + "other": { + 0: [mesh_dim_0] + }, + "output": { + 1: [mesh_dim_0] + }, + "running_mean": { + 0: [mesh_dim_0] + }, + "running_var": { + 0: [mesh_dim_0] + }, + "num_batches_tracked": {}, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}' + dim_partition_dict_mapping = { + "input": { + 1: [mesh_dim_0, mesh_dim_1] + }, + "other": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "output": { + 1: [mesh_dim_0, mesh_dim_1] + }, + "running_mean": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "running_var": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "num_batches_tracked": {}, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def non_split(self): + name = f'RR = RR x R' + dim_partition_dict_mapping = { + "input": {}, + "other": {}, + "output": {}, + "running_mean": {}, + "running_var": {}, + "num_batches_tracked": {}, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_batch(self, mesh_dim_0): + name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN' + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0] + }, + "other": {}, + "output": { + 0: [mesh_dim_0] + }, + "running_mean": {}, + "running_var": {}, + "num_batches_tracked": {}, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + # For SyncBN case, we don't need to do communication for weight and bias. + # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation + # to SyncBN operation instead of inserting a communication node. + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.IMPLICIT) + + # TODO: Temporary solution has no communication cost, + # above action should be added after the SyncBN replace pass completed. + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN' + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "other": {}, + "output": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "running_mean": {}, + "running_var": {}, + "num_batches_tracked": {}, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + # For SyncBN case, we don't need to do communication for gradients of weight and bias. + # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation + # to SyncBN operation instead of inserting a communication node. + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.IMPLICIT) + + # TODO: Temporary solution has no communication cost, + # above action should be added after the SyncBN replace pass completed. + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN' + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0], + 1: [mesh_dim_1], + }, + "other": { + 0: [mesh_dim_1], + }, + "output": { + 0: [mesh_dim_0], + 1: [mesh_dim_1], + }, + "running_mean": { + 0: [mesh_dim_1], + }, + "running_var": { + 0: [mesh_dim_1], + }, + "num_batches_tracked": {}, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = { + 0: [mesh_dim_1], + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + # For SyncBN case, we don't need to do communication for gradients of weight and bias. + # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation + # to SyncBN operation instead of inserting a communication node. + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=[mesh_dim_0], + comm_type=CommType.IMPLICIT) + + # TODO: Temporary solution has no communication cost, + # above action should be added after the SyncBN replace pass completed. + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def collate_strategies(self) -> List[ShardingStrategy]: + ''' + Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector. + ''' + + strategy_list = [] + # RS = RS x S + strategy_list.append(self.split_input_channel(0)) + strategy_list.append(self.split_input_channel(1)) + + # RR = RR x R + strategy_list.append(self.non_split()) + + # RS01 = RS01 x S01 + strategy_list.append(self.split_input_channel_1d(0, 1)) + + # The strategies with SYNC_BN are temporarily commented, + # because it requires some additional passes to keep runtime + # computation correctness. + + # TODO: The strategies below should be uncommented after runtime + # passes ready. + # SR = SR x R WITH SYNC_BN + strategy_list.append(self.split_input_batch(0)) + strategy_list.append(self.split_input_batch(1)) + + # SS = SS x S WITH SYNC_BN + strategy_list.append(self.split_input_both_dim(0, 1)) + strategy_list.append(self.split_input_both_dim(1, 0)) + + # S01R = S01R x R WITH SYNC_BN + strategy_list.append(self.split_input_batch_1d(0, 1)) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..fd7f811c8972412eaec88bb1dcfc639cdf1fe630 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py @@ -0,0 +1,111 @@ +import operator +from functools import reduce +from typing import List + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception, +) +from colossalai.tensor.sharding_spec import ShardingSpecException + +from .strategy_generator import StrategyGenerator + +__all__ = ['BinaryElementwiseStrategyGenerator'] + + +class BinaryElementwiseStrategyGenerator(StrategyGenerator): + """ + An BinaryElementwiseStrategyGenerator is a node handler which deals with elementwise operations + which have two operands and broadcasting occurs such as torch.add. + + The logical shape for this operation will be `input other`. + """ + + def validate(self) -> bool: + assert len(self.op_data) == 3, \ + f'BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}' + for name, op_data in self.op_data.items(): + if not isinstance(op_data.data, (torch.Tensor, int, float)): + raise TypeError(f'The operation data {name} is not a torch.Tensor/int/float.') + + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + + # since elementwise ops are not compute-intensive, + # we approximate the backward compute cost + # to be twice the fwd compute cost + fwd_compute_cost = reduce(operator.mul, shape) + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + # all input, output and outputs have the same shape + shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + + # compute fwd memory cost in bytes + # as the elementwise ops are not memory-intensive + # we approximate the fwd memroy cost to be the output + # and the backward memory cost to be grad of input and other + input_bytes = self._compute_size_in_bytes(strategy, 'input') + other_bytes = self._compute_size_in_bytes(strategy, 'other') + output_bytes = self._compute_size_in_bytes(strategy, 'output') + fwd_memory_cost = MemoryCost(activation=output_bytes) + bwd_memory_cost = MemoryCost(activation=input_bytes + other_bytes) + total_memory_cost = MemoryCost(activation=input_bytes + other_bytes + output_bytes) + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_memory_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): + # we check for the output logical shape to get the number of dimensions + dim_partition_list = [] + dim_size = len(self.op_data['output'].logical_shape) + + # enumerate all the 2D sharding cases + sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size) + dim_partition_list.extend(sharding_list_2d) + + # enumerate all the 1D sharding cases + sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size) + dim_partition_list.extend(sharding_list_1d_on_dim_0) + sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size) + dim_partition_list.extend(sharding_list_1d_on_dim_1) + + # add empty dict for fully replicated case + dim_partition_list.append({}) + + # sharding strategy bookkeeping + strategy_list = [] + + # convert these dim partition dict to sharding strategy + for dim_partition_dict in dim_partition_list: + dim_partition_dict_mapping = dict(input=dim_partition_dict, + other=dim_partition_dict, + output=dim_partition_dict) + + try: + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + communication_action_mapping = {} + + # get name + sharding_seq = sharding_spec_mapping['input'].sharding_sequence + name = f'{sharding_seq} = {sharding_seq} {sharding_seq}' + sharding_strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(sharding_strategy) + except ShardingSpecException: + continue + return strategy_list + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = self.enumerate_all_possible_output(0, 1) + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c2154b3104d3d52e994a2add25ddc796792e1c66 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py @@ -0,0 +1,584 @@ +import copy +import operator +import warnings +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import StrategyGenerator + + +class ConvStrategyGenerator(StrategyGenerator): + """ + ConvStrategyGenerator is a generic class to generate strategies. + The operation data is defined as `output = input x other + bias`. + """ + + def validate(self) -> bool: + ''' + In sanity check, we need make sure the input data having correct dimension size. + For Conv1d, the dim of input data should be 3([N, C, L]). + For Conv2d, the dim of input data should be 4([N, C, H, W]). + For Conv3d, the dim of input data should be 5([N, C, H, W, D]). + ''' + input_op_data = self.op_data['input'] + assert input_op_data.data.dim() in ( + 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' + + def update_compute_cost(self, strategy: ShardingStrategy): + ''' + Compute the computation cost per device with this specific strategy. + + Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + ''' + # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + # 1D: (L) * N * Cout * Cin * kernel + # 2D: (H * W) * N * Cout * Cin * kernel + # 3D: (H * W * D) * N * Cout * Cin * kernel + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + if self.has_bias: + # bias add is an element wise operation, so the cost is equal to product of output shape. + bias_compute_cost = reduce(operator.mul, sharded_output_shape) + + output_size = sharded_output_shape[2:] + output_size_product = reduce(operator.mul, output_size) + input_size = sharded_input_shape[2:] + input_size_product = reduce(operator.mul, input_size, 1) + kernel_size = sharded_other_shape[2:] + kernel_size_product = reduce(operator.mul, kernel_size, 1) + batch_size = sharded_input_shape[0] + channel_in = sharded_input_shape[1] + channel_out = sharded_other_shape[1] + + forward_compute_cost = output_size_product * batch_size * channel_in * channel_out * kernel_size_product + + backward_activation_cost = input_size_product * batch_size * channel_in * channel_out * kernel_size_product + backward_weight_cost = output_size_product * batch_size * channel_in * channel_out * kernel_size_product + backward_compute_cost = backward_weight_cost + backward_activation_cost + if self.has_bias: + forward_compute_cost += bias_compute_cost + backward_compute_cost += bias_compute_cost + total_compute_cost = forward_compute_cost + backward_compute_cost + + compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'other': self._compute_size_in_bytes(strategy, "other"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + if self.has_bias: + bias_size = self._compute_size_in_bytes(strategy, "bias") + forward_size_mapping['bias'] = bias_size + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + other + bias + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + other_grad + bias_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0] + }, + "other": { + 1: [mesh_dim_1] + }, + "output": { + 0: [mesh_dim_0], + 1: [mesh_dim_1] + }, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_1]} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping = {"input": input_comm_action} + + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + if self.has_bias: + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping["bias"] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_batch(self, mesh_dim_0): + name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0] + }, + "other": {}, + "output": { + 0: [mesh_dim_0], + }, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + if self.has_bias: + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping["bias"] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0], + 1: [mesh_dim_1], + }, + "other": { + 0: [mesh_dim_1] + }, + "output": { + 0: [mesh_dim_0], + }, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + output_comm_action = self.get_communication_action( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.AFTER) + + communication_action_mapping = {"output": output_comm_action} + + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + communication_action_mapping["other"] = other_comm_action + if self.has_bias: + if self.is_param("bias"): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping["bias"] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' + + dim_partition_dict_mapping = { + "input": { + 1: [mesh_dim_0], + }, + "other": { + 0: [mesh_dim_0], + 1: [mesh_dim_1], + }, + "output": { + 1: [mesh_dim_1], + }, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = { + 0: [mesh_dim_1], + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + output_comm_action = self.get_communication_action( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.AFTER) + input_comm_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + + communication_action_mapping = {"output": output_comm_action, "input": input_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_in_channel_weight_in_channel(self, mesh_dim_0): + name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' + + dim_partition_dict_mapping = { + "input": { + 1: [mesh_dim_0], + }, + "other": { + 0: [mesh_dim_0], + }, + "output": {}, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + output_comm_action = self.get_communication_action( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.AFTER) + + communication_action_mapping = {"output": output_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_weight_out_channel(self, mesh_dim_0): + name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' + + dim_partition_dict_mapping = { + "input": {}, + "other": { + 1: [mesh_dim_0], + }, + "output": { + 1: [mesh_dim_0], + }, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = { + 0: [mesh_dim_0], + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=0) + + communication_action_mapping = {"input": input_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def non_split(self): + name = f'RR = RR x RR' + + dim_partition_dict_mapping = { + "input": {}, + "other": {}, + "output": {}, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + @ignore_sharding_exception + def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0, mesh_dim_1], + }, + "other": {}, + "output": { + 0: [mesh_dim_0, mesh_dim_1], + }, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + if self.has_bias: + if self.is_param("bias"): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping["bias"] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): + name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + dim_partition_dict_mapping = { + "input": { + 1: [mesh_dim_0, mesh_dim_1], + }, + "other": { + 0: [mesh_dim_0, mesh_dim_1], + }, + "output": {}, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + output_comm_action = self.get_communication_action( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.AFTER) + + communication_action_mapping = {"output": output_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' + dim_partition_dict_mapping = { + "input": {}, + "other": { + 1: [mesh_dim_0, mesh_dim_1], + }, + "output": { + 1: [mesh_dim_0, mesh_dim_1], + }, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = { + 0: [mesh_dim_0, mesh_dim_1], + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + + communication_action_mapping = {"input": input_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def collate_strategies(self) -> List[ShardingStrategy]: + strategies = [] + # SS = SR x RS + strategies.append(self.split_input_batch_weight_out_channel(0, 1)) + strategies.append(self.split_input_batch_weight_out_channel(1, 0)) + + # SR = SR x RR + strategies.append(self.split_input_batch(0)) + strategies.append(self.split_input_batch(1)) + + # SR = SS x SR + strategies.append(self.split_input_both_dim_weight_in_channel(0, 1)) + strategies.append(self.split_input_both_dim_weight_in_channel(1, 0)) + + # RS = RS x SS + strategies.append(self.split_input_in_channel_weight_both_channel(0, 1)) + strategies.append(self.split_input_in_channel_weight_both_channel(1, 0)) + + # RR = RS x SR + strategies.append(self.split_input_in_channel_weight_in_channel(0)) + strategies.append(self.split_input_in_channel_weight_in_channel(1)) + + # RS = RR x RS + strategies.append(self.split_weight_out_channel(0)) + strategies.append(self.split_weight_out_channel(1)) + + # RR= RR x RR + strategies.append(self.non_split()) + + # S01R = S01R x RR + strategies.append(self.split_1d_parallel_on_input_batch(0, 1)) + + # RR = RS01 x S01R + strategies.append(self.split_1d_parallel_on_in_channel(0, 1)) + + # RS01 = RR x RS01 + strategies.append(self.split_1d_parallel_on_out_channel(0, 1)) + + return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..82a04ab52e739ae3db29efde2a66f30ff24cb8d0 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py @@ -0,0 +1,310 @@ +import copy +import operator +import warnings +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import StrategyGenerator + + +class EmbeddingStrategyGenerator(StrategyGenerator): + """ + EmbeddingStrategyGenerator is a generic class to generate strategies for nn.Embedding or F.embedding. + The operation data is defined as `output = input x other`. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + ''' + Compute the computation cost per device with this specific strategy. + + Note: The computation cost for the embedding handler is estimated as dense computing now. + It may not be accurate. + ''' + # TODO: estimate the embedding computation cost as sparse operation + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + + input_size_product = reduce(operator.mul, sharded_input_shape) + other_size_product = reduce(operator.mul, sharded_other_shape) + output_size_product = reduce(operator.mul, sharded_output_shape) + + forward_compute_cost = input_size_product * other_size_product + + backward_activation_cost = other_size_product * output_size_product / sharded_output_shape[-1] + backward_weight_cost = input_size_product * other_size_product + backward_compute_cost = backward_weight_cost + backward_activation_cost + + total_compute_cost = forward_compute_cost + backward_compute_cost + + compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'other': self._compute_size_in_bytes(strategy, "other"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + other + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + other_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def non_split(self): + name = f'RR = R x RR' + + dim_partition_dict_mapping = { + "input": {}, + "other": {}, + "output": {}, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + @ignore_sharding_exception + def split_input(self, mesh_dim_0): + name = f'S{mesh_dim_0}R = S{mesh_dim_0} x RR' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0] + }, + "other": {}, + "output": { + 0: [mesh_dim_0], + }, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0], + }, + "other": { + 1: [mesh_dim_1], + }, + "output": { + 0: [mesh_dim_0], + 1: [mesh_dim_1], + }, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping = {"input": input_comm_action} + + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "other": {}, + "output": { + 0: [mesh_dim_0, mesh_dim_1], + }, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + communication_action_mapping = {} + + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_embedding_dim(self, mesh_dim_0): + name = f'RS{mesh_dim_0} = R x RS{mesh_dim_0}' + + dim_partition_dict_mapping = { + "input": {}, + "other": { + 1: [mesh_dim_0], + }, + "output": { + 1: [mesh_dim_0], + }, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=0) + + communication_action_mapping = {"input": input_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}' + + dim_partition_dict_mapping = { + "input": {}, + "other": { + 1: [mesh_dim_0, mesh_dim_1], + }, + "output": { + 1: [mesh_dim_0, mesh_dim_1], + }, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + + communication_action_mapping = {"input": input_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def collate_strategies(self) -> List[ShardingStrategy]: + strategies = [] + + # RR= R x RR + strategies.append(self.non_split()) + + # SR = S x RR + strategies.append(self.split_input(0)) + strategies.append(self.split_input(1)) + + # SS = S x RS + strategies.append(self.split_input_and_embedding_dim(0, 1)) + strategies.append(self.split_input_and_embedding_dim(1, 0)) + + # S01R = S01 x RR + strategies.append(self.split_1d_parallel_on_input(0, 1)) + + # RS = R x RS + strategies.append(self.split_embedding_dim(0)) + strategies.append(self.split_embedding_dim(1)) + + # RS01 = R x RS01 + strategies.append(self.split_1d_parallel_on_embedding_dim(0, 1)) + + return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..753ab1726d4ce5c84dca66390c3258a04f8355d5 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py @@ -0,0 +1,53 @@ +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem + +from .strategy_generator import StrategyGenerator + +__all__ = ['GetattrGenerator'] + + +class GetattrGenerator(StrategyGenerator): + """ + PlaceholderGenerator is a generic class to generate strategies for placeholder node. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + + # compute fwd cost incurred + # fwd_cost = output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0) + + bwd_mem_cost = MemoryCost(activation=0, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + dim_partition_dict_mapping = { + "output": {}, + } + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Replica Attribute' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + return [strategy] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..2795c85449e9ecfe52067ed4c4154ca88ea2cf1f --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py @@ -0,0 +1,147 @@ +import copy +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import FollowingStrategyGenerator + +__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator'] + + +class GetItemStrategyGenerator(FollowingStrategyGenerator): + """ + GetItemStrategyGenerator is a generic class to generate strategies for operator.getitem. + The operation data is defined as `output = input[other]`. + + There are mainly three use cases: + 1. args_0._meta_data: torch.Tensor, args_1._meta_data: int + 2. args_0._meta_data: torch.Tensor, args_1._meta_data: slice + 3. args_0._meta_data: Tuple[torch.Tensor], args_1._meta_data: int + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + +class TensorStrategyGenerator(GetItemStrategyGenerator): + ''' + Deal with case 1 and 2. + ''' + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + dim_partition_dict_for_input = strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict + dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) + gather_input = 0 in dim_partition_dict_for_input + if gather_input: + logical_process_axis = dim_partition_dict_for_output.pop(0) + + shift_dim_partition_dict_for_output = {} + for dim, mesh_dim_list in dim_partition_dict_for_output.items(): + shift_dim_partition_dict_for_output[dim - 1] = mesh_dim_list + dim_partition_dict_for_output = shift_dim_partition_dict_for_output + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + if gather_input: + input_communication_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + logical_process_axis=logical_process_axis, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping["input"] = input_communication_action + + name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + strategy_list.append(strategy) + + for strategy in strategy_list: + self.update_communication_cost(strategy) + self.update_compute_cost(strategy) + self.update_memory_cost(strategy) + + return strategy_list + + +class TensorTupleStrategyGenerator(GetItemStrategyGenerator): + ''' + Deal with case 3. + ''' + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + index = self.op_data["index"].data + + for strategy_index, strategy in enumerate(self.predecessor_node.strategies_vector): + # the sharding spec for input in this case is a tuple of ShardingSpec. + sharding_spec_for_input = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_output = sharding_spec_for_input[index].dim_partition_dict + dim_partition_dict_mapping = {} + communication_action_mapping = {} + dim_partition_dict_mapping = { + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + sharding_spec_mapping["input"] = sharding_spec_for_input + input_sharding_info = f"get the {index} element from (" + for sharding_spec in sharding_spec_for_input: + input_sharding_info += f'{sharding_spec.sharding_sequence}, ' + input_sharding_info += ")" + name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..fbb6070f7e82c9a41848c626c6271d1a7b9d73ee --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py @@ -0,0 +1,195 @@ +import copy +import operator +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import StrategyGenerator + +__all__ = ['LayerNormGenerator'] + + +class LayerNormGenerator(StrategyGenerator): + """ + LayerNormGenerator is a generic class to generate strategies for LayerNorm operation. + The operation data is defined as `output = input x other + bias`. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + ''' + Compute the computation cost per device with this specific strategy. + + Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + ''' + # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + # TODO: a constant coefficient need to be added. + + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_weight_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() + if self.has_bias: + # bias add is an element wise operation, so the cost is equal to product of output shape. + bias_compute_cost = reduce(operator.mul, sharded_weight_shape) + # in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization. + input_batch_shape = sharded_input_shape[:-len(sharded_weight_shape)] + input_batch_product = reduce(operator.mul, input_batch_shape, 1) + norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1) + forward_compute_cost = input_batch_product * norm_kernel_product + backward_activation_compute_cost = input_batch_product * norm_kernel_product + # To compute gradient of on norm kernel element requires input_batch_product times computation, so + # the total cost is input_batch_product * norm_kernel_product + backward_weight_compute_cost = input_batch_product * norm_kernel_product + backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost + if self.has_bias: + forward_compute_cost += bias_compute_cost + backward_compute_cost += bias_compute_cost + total_compute_cost = forward_compute_cost + backward_compute_cost + compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'other': self._compute_size_in_bytes(strategy, "other"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + if self.has_bias: + bias_size = self._compute_size_in_bytes(strategy, "bias") + forward_size_mapping['bias'] = bias_size + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + other + bias + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + other_grad + bias_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def _generate_strategy_with_dim_partition(self, dim_partition): + dim_partition_dict_mapping = { + "input": dim_partition, + "other": {}, + "output": dim_partition, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence} x {sharding_spec_mapping["other"].sharding_sequence}' + total_mesh_dim_list = [] + for mesh_dim_list in dim_partition.values(): + total_mesh_dim_list.extend(mesh_dim_list) + # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. + if len(total_mesh_dim_list) == 1: + total_mesh_dim_list = total_mesh_dim_list[0] + communication_action_mapping = {} + + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=total_mesh_dim_list, + comm_type=CommType.HOOK) + communication_action_mapping["other"] = other_comm_action + + if self.has_bias: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=total_mesh_dim_list, + comm_type=CommType.HOOK) + communication_action_mapping["bias"] = bias_comm_action + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + return strategy + + def split_input_batch_single_mesh_dim(self, mesh_dim_0, batch_dimension_length): + strategy_list = [] + dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length) + for dim_partition in dim_partition_list: + strategy = self._generate_strategy_with_dim_partition(dim_partition) + strategy_list.append(strategy) + return strategy_list + + def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1, batch_dimension_length): + strategy_list = [] + dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length) + for dim_partition in dim_partition_list: + strategy = self._generate_strategy_with_dim_partition(dim_partition) + strategy_list.append(strategy) + return strategy_list + + @ignore_sharding_exception + def non_split(self): + name = f'RR = RR x R' + dim_partition_dict_mapping = { + "input": {}, + "other": {}, + "output": {}, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def collate_strategies(self) -> List[ShardingStrategy]: + ''' + Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector. + ''' + strategy_list = [] + input_data_dim = len(self.op_data["input"].logical_shape) + weight_data_dim = len(self.op_data["other"].logical_shape) + # in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization. + batch_dimension_length = input_data_dim - weight_data_dim + + # SR = SR x R with single mesh dim on batch dimensions + strategy_list.extend(self.split_input_batch_single_mesh_dim(0, batch_dimension_length)) + strategy_list.extend(self.split_input_batch_single_mesh_dim(1, batch_dimension_length)) + + # SR = SR x R with both mesh dims on batch dimensions + strategy_list.extend(self.split_input_batch_both_mesh_dim(0, 1, batch_dimension_length)) + + # RR = RR x R + strategy_list.append(self.non_split()) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..fa2246f952a984aad89665e9ba62a56beb8a482b --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -0,0 +1,994 @@ +import operator +from ast import arg +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import StrategyGenerator + + +class MatMulStrategyGenerator(StrategyGenerator): + """ + MatMulStrategyGenerator is a generic class to cover all matrix multiplication cases. + The operation data is defined as `output = input x other + bias`. + """ + + def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'other': self._compute_size_in_bytes(strategy, "other"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + if self.has_bias: + bias_size = self._compute_size_in_bytes(strategy, "bias") + size_mapping['bias'] = bias_size + + # compute fwd cost incurred + # fwd_cost = input + other + bias + output + fwd_activation_cost = sum([v for k, v in size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bias_grad + bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ['input', 'other', 'bias']]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + 0) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + +class DotProductStrategyGenerator(MatMulStrategyGenerator): + + def validate(self) -> bool: + input_op_data = self.op_data['input'] + other_op_data = self.op_data['other'] + assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1 + + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + fwd_compute_cost = sharded_input_shape[0] + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + return compute_cost + + @ignore_sharding_exception + def no_split(self): + name = f'R = R dot R' + dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + communication_action_mapping = {} + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_one_dim(self, mesh_dim): + name = f'R = S{mesh_dim} dot S{mesh_dim}' + + # get sharding spec + dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + + # get communication action + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['output'], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.AFTER) + communication_action_mapping = {"output": output_comm_action} + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + + # do not split dimensions for dot product + # R = R dot R + strategy_list.append(self.no_split()) + + # split two tensors in the same dimensions + # S = S dot S + strategy_list.append(self.split_one_dim(0)) + strategy_list.append(self.split_one_dim(1)) + + return strategy_list + + +class MatVecStrategyGenerator(MatMulStrategyGenerator): + + def validate(self) -> bool: + input_op_data = self.op_data['input'] + other_op_data = self.op_data['other'] + assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1 + + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + fwd_compute_cost = sharded_input_shape[0] + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + return compute_cost + + @ignore_sharding_exception + def no_split(self): + name = "R = R x R" + dim_partition_dict = {"input": {}, "other": {}, "output": {}} + + if self.has_bias: + dim_partition_dict['bias'] = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + @ignore_sharding_exception + def split_input_batch(self, mesh_dim): + name = f'S{mesh_dim}R = S{mesh_dim}R x R' + + # get sharding spec + dim_partition_dict = { + "input": { + 0: [mesh_dim] + }, + "other": {}, + "output": { + 0: [mesh_dim] + }, + } + + if self.has_bias: + dim_partition_dict['bias'] = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + + # get communication action + communication_action_mapping = {} + if self.is_param('other'): + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=1) + communication_action_mapping['other'] = other_comm_action + + if self.has_bias: + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=2) + communication_action_mapping['bias'] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + + # no split + strategy_list.append(self.no_split()) + + # split the batch dim for the first tensor only + strategy_list.append(self.split_input_batch(0)) + strategy_list.append(self.split_input_batch(1)) + + return strategy_list + + +class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): + + def __init__(self, operation_data_mapping, device_mesh, linear_projection_type='linear'): + super().__init__(operation_data_mapping, device_mesh) + self.linear_projection_type = linear_projection_type + + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + # C = AB + # C: [M, N], A: [M, P], B: [P, N] + # fwd cost = MNP (only count mul) + # bwd: 2 x fwd_cost + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() + dim_m_val = reduce(operator.mul, sharded_input_shape[:-1]) + dim_n_val = sharded_other_shape[-1] + dim_p_val = sharded_other_shape[0] + + fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=bwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + strategy.compute_cost = compute_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + strategies = [] + + # SS = SR x RS + strategies.append(self.split_lhs_space_rhs_space(0, 1)) + strategies.append(self.split_lhs_space_rhs_space(1, 0)) + + # SR = SS x SR + strategies.append(self.split_lhs_space_both_contract(0, 1)) + strategies.append(self.split_lhs_space_both_contract(1, 0)) + + # RS = RS x SS + strategies.append(self.split_rhs_space_both_contract(0, 1)) + strategies.append(self.split_rhs_space_both_contract(1, 0)) + + # RR= RS x SR + strategies.append(self.recompute_split_both_contract(0)) + strategies.append(self.recompute_split_both_contract(1)) + + # RS = RR x RS + strategies.append(self.split_rhs_space_only(0)) + strategies.append(self.split_rhs_space_only(1)) + + # S01R = S01R x RR + strategies.append(self.split_lhs_1st_dim_1d(0, 1)) + + # RR = RS01 x S01R + strategies.append(self.split_lhs_2nd_dim_1d(0, 1)) + + # RS01 = RR x RS01 + strategies.append(self.split_rhs_2nd_dim_1d(0, 1)) + + # RR = RR x RR + strategies.append(self.non_split()) + + return strategies + + @ignore_sharding_exception + def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): + # handle case SS = SR x RS + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0] + }, + "other": { + -1: [mesh_dim_1] + }, + "output": { + 0: [mesh_dim_0], + -1: [mesh_dim_1] + }, + } + + # linear bias only has one dimension, but addmm bias has same dimensions + # as the output logically. + if self.linear_projection_type == 'linear': + dim_partition_dict_mapping['bias'] = {-1: [mesh_dim_1]} + elif self.linear_projection_type == 'addmm': + dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0], -1: [mesh_dim_1]} + else: + raise ('Unsupported linear projection type') + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + communication_action_mapping = {} + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + + if self.is_param('other'): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping['input'] = input_comm_action + communication_action_mapping['other'] = other_comm_action + + # we only add allreduce comm action for linear bias, because + # allreduce comm action for addmm bias will be considered in post processing + if self.has_bias and self.linear_projection_type == 'linear': + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping['bias'] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): + # handle the case SR = SS x SR + name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' + + # get sharding spec mapping + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0], + -1: [mesh_dim_1] + }, + "other": { + 0: [mesh_dim_1] + }, + "bias": {}, + "output": { + 0: [mesh_dim_0] + }, + } + + # linear bias only has one dimension, but addmm bias has same dimensions + # as the output logically. + if self.linear_projection_type == 'linear': + dim_partition_dict_mapping['bias'] = {} + elif self.linear_projection_type == 'addmm': + dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0]} + else: + raise ('Unsupported linear projection type') + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication action mapping + communication_action_mapping = {} + + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.AFTER) + + if self.is_param('other'): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping['other'] = other_comm_action + communication_action_mapping['output'] = output_comm_action + + # we only add allreduce comm action for linear bias, because + # allreduce comm action for addmm bias will be considered in post processing + if self.has_bias and self.linear_projection_type == 'linear': + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping['bias'] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' + + # get sharding specs + dim_partition_dict_mapping = { + "input": { + -1: [mesh_dim_0] + }, + "other": { + 0: [mesh_dim_0], + -1: [mesh_dim_1] + }, + "bias": { + -1: [mesh_dim_1] + }, + "output": { + -1: [mesh_dim_1] + }, + } + + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication actions + communication_action_mapping = {} + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['output'], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.AFTER) + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['input'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping["input"] = input_comm_action + communication_action_mapping['output'] = output_comm_action + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def recompute_split_both_contract(self, mesh_dim): + name = f'RR = RS{mesh_dim} x S{mesh_dim}R' + + # get sharding spec + dim_partition_dict_mapping = { + "input": { + -1: [mesh_dim] + }, + "other": { + 0: [mesh_dim] + }, + "bias": {}, + "output": {}, + } + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication action + communication_action_mapping = {} + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['output'], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.AFTER) + + communication_action_mapping['output'] = output_comm_action + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_rhs_space_only(self, mesh_dim): + name = f'RS{mesh_dim} = RR x RS{mesh_dim}' + + # get sharding spec + dim_partition_dict_mapping = { + "input": {}, + "other": { + -1: [mesh_dim] + }, + "bias": { + -1: [mesh_dim] + }, + "output": { + -1: [mesh_dim] + }, + } + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication actions + communication_action_mapping = {} + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['input'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=0) + + communication_action_mapping['input'] = input_comm_action + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + # get sharding spec + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "other": {}, + "bias": {}, + "output": { + 0: [mesh_dim_0, mesh_dim_1] + }, + } + + # linear bias only has one dimension, but addmm bias has same dimensions + # as the output logically. + if self.linear_projection_type == 'linear': + dim_partition_dict_mapping['bias'] = {} + elif self.linear_projection_type == 'addmm': + dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0, mesh_dim_1]} + else: + raise ('Unsupported linear projection type') + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication action + communication_action_mapping = {} + if self.is_param('other'): + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=1) + communication_action_mapping['other'] = other_comm_action + + # we only add allreduce comm action for linear bias, because + # allreduce comm action for addmm bias will be considered in post processing + if self.has_bias and self.linear_projection_type == 'linear': + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping['bias'] = bias_comm_action + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + + # get sharding spec + dim_partition_dict_mapping = { + "input": { + -1: [mesh_dim_0, mesh_dim_1] + }, + "other": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "bias": {}, + "output": {}, + } + + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication action + communication_action_mapping = {} + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['output'], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.AFTER) + communication_action_mapping['output'] = output_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' + + # get sharding spec + dim_partition_dict_mapping = { + "input": {}, + "other": { + -1: [mesh_dim_0, mesh_dim_1] + }, + "bias": { + -1: [mesh_dim_0, mesh_dim_1] + }, + "output": { + -1: [mesh_dim_0, mesh_dim_1] + }, + } + + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication action + communication_action_mapping = {} + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['input'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['input'] = input_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def non_split(self): + name = f'RR = RR x RR' + + # get sharding spec + dim_partition_dict_mapping = { + "input": {}, + "other": {}, + "bias": {}, + "output": {}, + } + + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication action + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def validate(self) -> bool: + assert "input" in self.op_data + assert "other" in self.op_data + + # make sure the other has 2 dim + input_data = self.op_data['input'] + other_data = self.op_data['other'] + assert input_data.data.dim() > 0 and other_data.data.dim() == 2 + assert other_data.logical_shape[0] == input_data.logical_shape[-1] + + if self.has_bias: + bias_data = self.op_data['bias'] + assert bias_data.logical_shape[-1] == other_data.logical_shape[-1] + + +class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): + """ + Generate sharding strategies for the batched matrix multiplication. + + A batched matrix multiplication can be viewed as + [b, i, k] x [b, k, j] -> [b, i, j] + + The bias term is considered to have a 2D logical shape. + + Note: This class will be used to generate strategies for torch.bmm + and torch.addbmm. However, the result of torch.addbmm is not correct, + some extra runtime apply actions are required to keep numerical correctness. + """ + + # TODO: torch.addbmm correctness issue need to be fixed. + def __init__(self, *args, **kwargs): + self.squeeze_batch_dim = False + super().__init__(*args, **kwargs) + + def _pop_batch_dim_sharding_for_output(self, dim_partition_dict): + # remove partition dict for dim 0 + dim_partition_dict['output'].pop(0, None) + + # decrease the remaining dim index by 1 + temp_dim_partition = {} + keys = list(dim_partition_dict['output'].keys()) + for key in keys: + val = dim_partition_dict['output'].pop(key) + temp_dim_partition[key - 1] = val + dim_partition_dict['output'].update(temp_dim_partition) + + def validate(self) -> bool: + input_op_data = self.op_data['input'] + other_op_data = self.op_data['other'] + assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3 + + if 'bias' in self.op_data: + bias_op_data = self.op_data['bias'] + assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2 + + if self.op_data['output'].data.dim() == 2: + # addbmm will shrink the first batch dim + self.squeeze_batch_dim = True + + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul, + self.op_data['output'].data.shape) + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + strategy.compute_cost = compute_cost + + @ignore_sharding_exception + def split_one_batch_dim(self, mesh_dim): + name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}' + + # get sharding_spec + dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}} + if self.squeeze_batch_dim: + self._pop_batch_dim_sharding_for_output(dim_partition_dict) + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + + # get communication actions + communication_action_mapping = {} + if self.has_bias: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1): + name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "other": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "bias": {}, + "output": { + 0: [mesh_dim_0, mesh_dim_1] + } + } + if self.squeeze_batch_dim: + self._pop_batch_dim_sharding_for_output(dim_partition_dict) + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + + # get communication actions + communication_action_mapping = {} + if self.has_bias: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1): + name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0], + 1: [mesh_dim_1] + }, + "other": { + 0: [mesh_dim_0] + }, + "bias": { + 0: [mesh_dim_1] + }, + "output": { + 0: [mesh_dim_0], + 1: [mesh_dim_1] + } + } + if self.squeeze_batch_dim: + self._pop_batch_dim_sharding_for_output(dim_partition_dict) + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + + # get communication actions + communication_action_mapping = {} + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=1) + communication_action_mapping['other'] = other_comm_action + + if self.has_bias: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action + # for addbmm case, other is the third argument instead of second. + communication_action_mapping['other'].arg_index += 1 + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1): + name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0] + }, + "other": { + 0: [mesh_dim_0], + 2: [mesh_dim_1] + }, + "bias": { + 1: [mesh_dim_1] + }, + "output": { + 0: [mesh_dim_0], + 2: [mesh_dim_1] + } + } + if self.squeeze_batch_dim: + self._pop_batch_dim_sharding_for_output(dim_partition_dict) + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + + # get communication actions + communication_action_mapping = {} + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['input'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['input'] = input_comm_action + + if self.has_bias: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE) + communication_action_mapping['bias'] = bias_comm_action + # for addbmm case, other is the second argument instead of first. + communication_action_mapping['input'].arg_index += 1 + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1): + name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0], + 2: [mesh_dim_1] + }, + "other": { + 0: [mesh_dim_0], + 1: [mesh_dim_1] + }, + "bias": {}, + "output": { + 0: [mesh_dim_0], + } + } + if self.squeeze_batch_dim: + self._pop_batch_dim_sharding_for_output(dim_partition_dict) + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + + # get communication actions + communication_action_mapping = {} + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['output'], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.AFTER) + communication_action_mapping['output'] = output_comm_action + + if self.has_bias: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + device_mesh_is_1d = True + if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape: + device_mesh_is_1d = False + + if device_mesh_is_1d: + # split only the batch dimension + # Sb = Sb x Sb + # can be None as it is only for 1D device mesh + # only for 1D device mesh + if len(self.device_mesh.mesh_shape) == 1: + mesh_dim = 0 + else: + mesh_dim = self.device_mesh.mesh_shape.index(1) + strategy_list.append(self.split_one_batch_dim(mesh_dim)) + else: + # for 2D device mesh + # split batch dim of two inputs and the i dim of the first tensor + # SbSi = SbSi x Sb + strategy_list.append(self.split_batch_dim_lhs_space(0, 1)) + strategy_list.append(self.split_batch_dim_lhs_space(1, 0)) + + # split batch dim of two inputs and the j of the second tensor + # SbSj = Sb x SbSj + strategy_list.append(self.split_batch_dim_rhs_space(0, 1)) + strategy_list.append(self.split_batch_dim_rhs_space(1, 0)) + + # split batch dim of two inputs and the k dim of two inputs + # Sb = SbSk x SbSk, need to all-reduce by k dim + strategy_list.append(self.split_batch_dim_both_contract(0, 1)) + strategy_list.append(self.split_batch_dim_both_contract(1, 0)) + + # split two batch dim + strategy_list.append(self.split_two_batch_dim(0, 1)) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..9df6d2fbfa127b71eba66256fb27f204fb1da5fe --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py @@ -0,0 +1,118 @@ +import copy +import operator +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception, +) + +from .strategy_generator import StrategyGenerator + + +class NormalPoolStrategyGenerator(StrategyGenerator): + """ + NormalPoolStrategyGenerator is a generic class to generate strategies for pool operation like MaxPoolxd. + The reason we call this normal pool is AvgPoolxd and MaxPoolxd are taking the kernel size element from image, + and reduce them depening on the operation type. + """ + + def validate(self) -> bool: + ''' + In sanity check, we need make sure the input data having correct dimension size. + For Pool1d, the dim of input data should be 3([N, C, L]). + For Pool2d, the dim of input data should be 4([N, C, H, W]). + For Pool3d, the dim of input data should be 5([N, C, H, W, D]). + ''' + input_op_data = self.op_data['input'] + assert input_op_data.data.dim() in ( + 3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].' + + def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem: + ''' + Compute the computation cost per device with this specific strategy. + + Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + ''' + # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + # 1D: (Lout) * N * C * kernel + # 2D: (H * W) * N * Cout * Cin * kernel + # 3D: (H * W * D) * N * Cout * Cin * kernel + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + + kernel_size = self.op_data["other"].data + if isinstance(kernel_size, int): + kernel_size = [kernel_size] * (len(sharded_output_shape) - 2) + kernel_size_product = reduce(operator.mul, kernel_size) + output_size_product = reduce(operator.mul, sharded_output_shape) + input_size_product = reduce(operator.mul, sharded_input_shape) + + forward_compute_cost = output_size_product * kernel_size_product + backward_compute_cost = input_size_product * kernel_size_product + + total_compute_cost = forward_compute_cost + backward_compute_cost + + compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items()]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, parameter=0) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def _generate_strategy_with_dim_partition(self, dim_partition): + dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}' + communication_action_mapping = {} + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + return strategy + + def enumerate_all_possible_batch_dimensions_dim_partition(self, mesh_dim_0, mesh_dim_1): + dim_partition_list = [] + dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_0, 2)) + dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_1, 2)) + dim_partition_list.extend(enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, 2)) + # append {} for non_split case + dim_partition_list.append({}) + + return dim_partition_list + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + + dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1) + for dim_partition in dim_partition_list: + strategy = self._generate_strategy_with_dim_partition(dim_partition) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..69d1642d4f808038d0eeb58547a7b1c0604c85eb --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py @@ -0,0 +1,121 @@ +from typing import Dict, List + +from torch.fx import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.device.device_mesh import DeviceMesh + +from .strategy_generator import OutputStrategyGenerator + +__all__ = ['OutputGenerator'] + + +class OutputGenerator(OutputStrategyGenerator): + """ + OutputGenerator is a generic class to generate strategies for Output Node. + """ + + def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, + predecessor_nodes: List[Node], output_option: str): + super().__init__(operation_data_mapping, device_mesh, predecessor_nodes) + self.output_option = output_option + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + fwd_mem_cost = MemoryCost(activation=0, parameter=0) + + bwd_mem_cost = MemoryCost(activation=0, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=0, parameter=0) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def replica_strategy(self) -> List[ShardingStrategy]: + """ + Generate replica strategy for output node. + """ + dim_partition_dict_mapping = {} + dim_partition_dict_for_output = [] + for index, _ in enumerate(self.predecessor_nodes): + mapping_name = f"input_{index}" + if isinstance(self.op_data[mapping_name].data, (tuple, list)): + dim_partition_dict_for_input = [{} for _ in range(len(self.op_data[mapping_name].data))] + else: + dim_partition_dict_for_input = {} + dim_partition_dict_mapping[mapping_name] = dim_partition_dict_for_input + dim_partition_dict_for_output.append(dim_partition_dict_for_input) + + if len(dim_partition_dict_for_output) == 1: + dim_partition_dict_for_output = dim_partition_dict_for_output[0] + else: + dim_partition_dict_for_output = tuple(dim_partition_dict_for_output) + + dim_partition_dict_mapping['output'] = dim_partition_dict_for_output + + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Replica Output' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + return strategy + + def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]: + """ + Generate distributed strategy for output node. + """ + # TODO: need to take care of the case when the first element of output only need to be sharded. + output_op_data = self.op_data['output'] + if isinstance(output_op_data.data, tuple): + length = len(output_op_data.data) + dim_partition_dict_mapping = { + "output": [{ + 0: mesh_list + }] * length, + } + else: + dim_partition_dict_mapping = { + "output": { + 0: mesh_list + }, + } + for index, _ in enumerate(self.predecessor_nodes): + mapping_name = f"input_{index}" + dim_partition_dict_mapping[mapping_name] = {0: mesh_list} + + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Distributed Output' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + return strategy + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + mesh_list = [0, 1] + if self.output_option == 'replicated': + strategy_list.append(self.replica_strategy()) + elif self.output_option == 'distributed': + strategy_list.append(self.distributed_strategy(mesh_list)) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..779a7ced93bb503c390bd89382d087230e48d2f0 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py @@ -0,0 +1,100 @@ +from typing import Dict, List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.device.device_mesh import DeviceMesh + +from .strategy_generator import StrategyGenerator + +__all__ = ['PlaceholderGenerator'] + + +class PlaceholderGenerator(StrategyGenerator): + """ + PlaceholderGenerator is a generic class to generate strategies for placeholder node. + """ + + def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, + placeholder_option: str): + super().__init__(operation_data_mapping, device_mesh) + self.placeholder_option = placeholder_option + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + + # compute fwd cost incurred + # fwd_cost = output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0) + + bwd_mem_cost = MemoryCost(activation=0, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def replica_placeholder(self) -> ShardingStrategy: + """ + Generate replica strategy for placeholder node. + """ + dim_partition_dict_mapping = { + "output": {}, + } + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Replica Placeholder' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + return strategy + + def distributed_placeholder(self, mesh_list) -> ShardingStrategy: + """ + Generate distributed strategy for placeholder node. + """ + dim_partition_dict_mapping = { + "output": { + 0: mesh_list + }, + } + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Distributed Placeholder' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + return strategy + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + if self.placeholder_option == 'distributed': + mesh_list = [0, 1] + distributed_strategy = self.distributed_placeholder(mesh_list) + strategy_list.append(distributed_strategy) + else: + assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported' + replicated_strategy = self.replica_placeholder() + strategy_list.append(replicated_strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..0b3506c27e4c088448500704d3675177fb001795 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py @@ -0,0 +1,122 @@ +import copy +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.tensor.sharding_spec import ShardingSpec + +from .strategy_generator import FollowingStrategyGenerator + +__all__ = ['ReshapeGenerator'] + + +class ReshapeGenerator(FollowingStrategyGenerator): + """ + ReshapeGenerator which deals with the sharding strategies of Reshape Op, such as torch.Tensor.permute. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + # For reshape function, to keep the computing correctness we keep the sharding + # spec of input is fully replicated. In addition, we will keep the output in + # replica status and let the successor node choose the way to resharding the + # output node. Therefore, the different strategies of input node with same + # output sharding spec will generate same strategy for reshape function. + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + dim_partition_dict_for_output = {} + if isinstance(self.op_data["output"].data, tuple): + dim_partition_dict_for_output = [{} for _ in range(len(self.op_data["output"].data))] + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}' + + total_mesh_dim_list = [] + for mesh_dim_list in dim_partition_dict_for_input.values(): + total_mesh_dim_list.extend(mesh_dim_list) + # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. + if len(total_mesh_dim_list) == 1: + total_mesh_dim_list = total_mesh_dim_list[0] + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + logical_process_axis=total_mesh_dim_list, + comm_type=CommType.BEFORE, + arg_index=0) + input_comm_action.comm_spec.gather_dim = total_mesh_dim_list + + elif len(total_mesh_dim_list) >= 2: + source_spec = sharding_spec_mapping["input"] + target_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=source_spec.entire_shape, + dim_partition_dict={}) + comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) + + else: + input_comm_action = None + + if input_comm_action is not None: + communication_action_mapping["input"] = input_comm_action + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + for strategy in strategy_list: + self.update_communication_cost(strategy) + self.update_compute_cost(strategy) + self.update_memory_cost(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ebadd043e2c2e563fcdc611567bca3ededfa51 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py @@ -0,0 +1,104 @@ +import copy +import operator +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ( + check_keep_sharding_status, + detect_reshape_mapping, + infer_output_dim_partition_dict, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +__all__ = ['SoftmaxGenerator'] + + +class SoftmaxGenerator(FollowingStrategyGenerator): + """ + SoftmaxGenerator is used to generate strategies for torch.nn.Softmax or F.softmax. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + ''' + Compute the computation cost per device with this specific strategy. + ''' + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + input_size_product = reduce(operator.mul, sharded_input_shape) + output_size_product = reduce(operator.mul, sharded_output_shape) + + forward_compute_cost = output_size_product * 2 + backward_compute_cost = input_size_product + total_compute_cost = forward_compute_cost + backward_compute_cost + compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) + softmax_dim = self.op_data['softmax_dim'].data + + if softmax_dim in dim_partition_dict_for_input: + recover_dims = dim_partition_dict_for_input.pop(softmax_dim) + + dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..6d68521aaea7989c085c24f32a8dc92f4b1b71fc --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py @@ -0,0 +1,298 @@ +import operator +from abc import ABC, abstractmethod +from functools import reduce +from typing import Any, Dict, List, Union + +import torch +from torch.fx import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + OperationData, + OperationDataType, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.tensor.utils import convert_dim_partition_dict + + +class StrategyGenerator(ABC): + """ + StrategyGenerator is used to generate the same group of sharding strategies. + + TODO: remove the original strategy_generator.py after refactoring + """ + + def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh): + self.op_data = operation_data_mapping + self.device_mesh = device_mesh + + # validate the whether operation data is of desired value + self.validate() + + @property + def has_bias(self): + """ + A utility method to check for the existence of bias operand for convenience. + """ + return 'bias' in self.op_data + + def is_param(self, op_data_name): + other_data = self.op_data[op_data_name] + return other_data.type == OperationDataType.PARAM + + def is_buffer(self, op_data_name): + other_data = self.op_data[op_data_name] + return other_data.type == OperationDataType.BUFFER + + def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec], + communication_action_mapping: Dict[str, CommSpec]): + """ + A factory method to produce a ShardingStrategy object. + + Args: + sharding_spec_mapping (Dict[str, ShardingSpec]): the mapping between the operation data name and the ShardingSpec object. + communication_action_mapping (Dict[str, CommSpec]): the mapping between the operation data name and the CommSpec object. + """ + sharding_specs = self.replace_op_name_with_op_data(sharding_spec_mapping) + communication_actions = self.replace_op_name_with_op_data(communication_action_mapping) + return ShardingStrategy(name=name, sharding_specs=sharding_specs, communication_actions=communication_actions) + + def to_sharding_spec_mapping(self, mapping: Dict[str, Dict[int, List[int]]]): + """ + A utility method to convert the the dim partition dict to a ShardingSpec object. + + Args: + mapping (Dict[str, Dict[int, List[int]]]): the key of the mapping is the operation data name and the value is a dim partition dictionary. + + Notes: + The op_data.data is commonly type of torch.Tensor, torch.nn.Parameter, so the sharding spec is easy to create from the shape of the data. + However, if the op_data.data is of other non-iterative types, such as float or int, we should return None. If the op_data.data is of some iterative types, such as + list or tuple, we should return a list of ShardingSpec objects follow the same rule as above mentioned. + """ + results = {} + for op_data_name, dim_partition_dict in mapping.items(): + if op_data_name in self.op_data: + op_data = self.op_data[op_data_name] + + def _to_sharding_spec( + data: any, logical_shape: any, + dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]: + """ + This is a recursive function to convert the dim partition dict to a ShardingSpec object. + """ + if isinstance(data, torch.Tensor): + dim_size = len(logical_shape) + dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict) + sharding_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=logical_shape, + dim_partition_dict=dim_partition_dict) + return sharding_spec + elif isinstance(data, (list, tuple)): + sharding_spec = [] + for data_element, logical_shape_element, dim_partition_dict_element in zip( + data, logical_shape, dim_partition_dict): + sharding_spec.append( + _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element)) + return sharding_spec + else: + return None + + sharding_spec = _to_sharding_spec(op_data.data, op_data.logical_shape, dim_partition_dict) + results[op_data_name] = sharding_spec + return results + + def replace_op_name_with_op_data(self, mapping: Dict[str, Any]): + """ + Convert the key of the dictionary from the operation data name to an OperationData object. + """ + results = {} + for k, v in mapping.items(): + op_data = self.op_data[k] + results[op_data] = v + return results + + def get_communication_spec(self, sharding_spec: ShardingSpec, communication_pattern: CollectiveCommPattern, + logical_process_axis: Union[int, List[int]]): + """ + A factory method to produce a CommSpec object. + """ + return CommSpec(comm_pattern=communication_pattern, + sharding_spec=sharding_spec, + logical_process_axis=logical_process_axis) + + def get_communication_action(self, + sharding_spec: ShardingSpec, + communication_pattern: CollectiveCommPattern, + logical_process_axis: Union[int, List[int]], + comm_type: CommType, + arg_index: int = -1, + key_for_kwarg: any = None) -> CommAction: + """ + A factory method to produce a CommAction object. + """ + return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec, + communication_pattern=communication_pattern, + logical_process_axis=logical_process_axis), + comm_type=comm_type, + arg_index=arg_index, + key_for_kwarg=key_for_kwarg) + + def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + """ + Compute the communication cost involved in the forward and backward iteration. + """ + + comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0) + + def _compute_and_add(op_data: OperationData, comm_spec: CommSpec): + num_ele_in_comm = comm_spec.get_comm_cost() + dtype = op_data.data.dtype + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + for phase, cost in num_ele_in_comm.items(): + num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes + comm_cost.fwd += num_ele_in_comm['forward'] + comm_cost.bwd += num_ele_in_comm['backward'] + comm_cost.total += num_ele_in_comm['total'] + + # check if communication action exists + # if so, loop over each action and compute the cost of each action + if strategy.communication_actions is not None: + for operand, comm_action in strategy.communication_actions.items(): + if isinstance(comm_action, CommAction): + comm_spec = comm_action.comm_spec + else: + # this condition branch will be removed after all the handler updated. + comm_spec = comm_action + if isinstance(comm_spec, dict): + src_spec = comm_spec['src_spec'] + tgt_spec = comm_spec['tgt_spec'] + shape_consistency_manager = ShapeConsistencyManager() + _, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec) + for comm_spec_ in comm_action_sequence: + _compute_and_add(operand, comm_spec_) + else: + _compute_and_add(operand, comm_spec) + + # update the communication cost attribute in-place + strategy.communication_cost = comm_cost + return strategy + + @abstractmethod + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + """ + Customize this method to compute the computation flops. + """ + pass + + @abstractmethod + def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + """ + Customize this method to compute the memory cost in bytes. + """ + pass + + def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str): + """ + Compute the size of a tensor in bytes. + + Args: + strategy (ShardingStrategy): the ShardingStrategy generated. + key (str): the name of the operation data defined by the generator. + """ + op_data = self.op_data[key] + + def _compute_size_in_bytes_helper(sharding_spec, meta_data): + sharded_shape = sharding_spec.get_sharded_shape_per_device() + if len(sharded_shape) == 0: + num_elements = 1 + else: + num_elements = reduce(operator.mul, sharded_shape) + dtype = getattr(meta_data, 'dtype') + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + return num_elements * size_per_elem_bytes + + if isinstance(op_data.data, tuple): + assert isinstance(strategy.sharding_specs[op_data], list), \ + 'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.' + total_bytes = 0 + for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]): + meta_data = op_data.data[index] + if isinstance(meta_data, torch.Tensor): + element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data) + else: + # if meta_data is not a tensor, we count the memroy as 0 + element_bytes = 0 + total_bytes += element_bytes + + else: + if isinstance(op_data.data, torch.Tensor): + total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data) + else: + # if op_data.data is not a tensor, we count the memroy as 0 + total_bytes = 0 + + return total_bytes + + def generate(self) -> List[ShardingStrategy]: + """ + Generate all possible sharding strategies for this operation. + """ + strategies = self.collate_strategies() + + # some strategies may be None as ignore_sharding_exception may return None + # when ShardingSpecException occurs. + # thus, remove those None values + strategies = [strategy for strategy in strategies if strategy] + + # update the costs + # update mete info on cost + # these update methods are all in-place, the default method will do nothing + # the cost info will only be added if the child class overrides these methods + for strategy in strategies: + self.update_communication_cost(strategy) + self.update_compute_cost(strategy) + self.update_memory_cost(strategy) + + return strategies + + @abstractmethod + def collate_strategies(self) -> List[ShardingStrategy]: + pass + + @abstractmethod + def validate(self) -> bool: + """ + Validate if the operands are of desired shape. + If True, means this generator can be used for the current operation. + """ + pass + + +class FollowingStrategyGenerator(StrategyGenerator): + """ + FollowingStrategyGenerator is used to generate the sharding strategies which depends on its predecessor node. + + TODO: remove the original strategy_generator.py after refactoring + """ + + def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, + predecessor_node: Node): + self.op_data = operation_data_mapping + self.device_mesh = device_mesh + self.predecessor_node = predecessor_node + + +class OutputStrategyGenerator(StrategyGenerator): + """ + OutputStrategyGenerator is used to generate the sharding strategies for Output Node. + """ + + def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, + predecessor_nodes: List[Node]): + super().__init__(operation_data_mapping, device_mesh) + self.predecessor_nodes = predecessor_nodes diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..a0fbc58d70c0feba2c78305fb14d9bcb38a82e41 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py @@ -0,0 +1,113 @@ +import copy +import operator +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ( + check_keep_sharding_status, + detect_reshape_mapping, + infer_output_dim_partition_dict, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.tensor.sharding_spec import ShardingSpec + +__all__ = ['SumGenerator'] + + +class SumGenerator(FollowingStrategyGenerator): + """ + SumGenerator deals with the sharding strategies of torch.sum op. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + input_size_product = reduce(operator.mul, sharded_input_shape) + output_size_product = reduce(operator.mul, sharded_output_shape) + + compute_cost = TrainCycleItem(fwd=input_size_product, + bwd=output_size_product, + total=input_size_product + output_size_product) + + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) + sum_dims, sum_mapping_dict = self.op_data['sum_info'].data + + # TODO: a better way to handle the distributed sum is sum all the data on chip and then do all reduce + # among all the shard groups + recover_dims = [] + dim_partition_dict_for_output = {} + for dim in dim_partition_dict_for_input: + if dim in sum_dims: + recover_dims.append(dim) + elif dim in sum_mapping_dict: + dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim] + else: + raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims') + + for dim in recover_dims: + dim_partition_dict_for_input.pop(dim) + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..93cfc9eeea532ac4383f0821008deeccb13951d0 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py @@ -0,0 +1,67 @@ +import copy +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.tensor.sharding_spec import ShardingSpec + +from .strategy_generator import StrategyGenerator + +__all__ = ['TensorConstructorGenerator'] + + +class TensorConstructorGenerator(StrategyGenerator): + """ + TensorConstructorGenerator which deals with + the sharding strategies for tensor constructor operation, such as torch.arange. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + bwd_mem_cost = MemoryCost(activation=0, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + dim_partition_dict_mapping = { + "output": {}, + } + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Replica Tensor Constructor' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..b867a30686eb97a55096895d344dcc28b51f347a --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py @@ -0,0 +1,77 @@ +import copy +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) + +from .strategy_generator import FollowingStrategyGenerator + +__all__ = ['UnaryElementwiseGenerator'] + + +class UnaryElementwiseGenerator(FollowingStrategyGenerator): + """ + UnaryElementwiseGenerator which deals with the sharding strategies of UnaryElementwiseOp. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + # For element-wise function, we keep the sharding spec of output node same as + # the input. Therefore, the different strategies of input node with same + # output sharding spec will generate same strategy for element-wise function. + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..fa941f2cc51dc4d817bfc8f49c54bbaf7a8a5407 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py @@ -0,0 +1,98 @@ +import copy +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception, +) + +from .strategy_generator import StrategyGenerator + +__all__ = ['WhereGenerator'] + + +class WhereGenerator(StrategyGenerator): + """ + WhereGenerator is a generic class to generate strategies for Where operation. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'condition': self._compute_size_in_bytes(strategy, "condition"), + 'x': self._compute_size_in_bytes(strategy, "x"), + 'y': self._compute_size_in_bytes(strategy, "y"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = condition + x + y + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0) + + # compute bwd cost incurred + # bwd_cost = condition_grad + x_grad + y_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items()]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, parameter=0) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def _generate_strategy_with_dim_partition(self, dim_partition): + dim_partition_dict_mapping = { + "condition": dim_partition, + "x": dim_partition, + "y": dim_partition, + "output": dim_partition + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["condition"].sharding_sequence} x {sharding_spec_mapping["x"].sharding_sequence} x {sharding_spec_mapping["y"].sharding_sequence}' + communication_action_mapping = {} + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + return strategy + + def enumerate_all_possible_output_spec(self, mesh_dim_0, mesh_dim_1, dimension_length): + dim_partition_list = [] + dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_0, dimension_length)) + dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_1, dimension_length)) + dim_partition_list.extend(enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dimension_length)) + # append {} for non_split case + dim_partition_list.append({}) + + return dim_partition_list + + def collate_strategies(self) -> List[ShardingStrategy]: + ''' + Generate every possible strategies for a where node, and record all strategies into the strategies_vector. + ''' + strategy_list = [] + + dimension_length = len(self.op_data["output"].logical_shape) + dim_partition_list = self.enumerate_all_possible_output_spec(0, 1, dimension_length) + for dim_partition in dim_partition_list: + strategy = self._generate_strategy_with_dim_partition(dim_partition) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..86f90694e0604f72e9564020ccab455cfdee29a0 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py @@ -0,0 +1,81 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator, SumGenerator + +__all__ = ['SumHandler'] + + +@operator_registry.register(torch.Tensor.sum) +@operator_registry.register(torch.sum) +class SumHandler(NodeHandler): + """ + A SumHandler which deals with the sharding strategies for torch.sum or torch.Tensor.sum. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(SumGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + if len(self.node.args) > 1: + sum_dims = self.node.args[1] + else: + sum_dims = tuple(range(self.node.args[0]._meta_data.dim())) + + if isinstance(sum_dims, int): + sum_dims = (sum_dims,) + + # recover negative value to positive + num_dims = self.node.args[0]._meta_data.dim() + for i in range(len(sum_dims)): + if sum_dims[i] < 0: + sum_dims[i] += num_dims + + # mapping the input dims to output dims + # For examples: + # input: torch.rand(2, 3, 4, 5) + # output: torch.sum(input, (0, 2)) + # sum_mapping_dict = {1: 0, 3: 1} + # sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input + # sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input + sum_mapping_dict = {} + if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']: + for i in range(num_dims): + sum_mapping_dict.update({i: i}) + else: + output_index = 0 + for i in range(num_dims): + if i not in sum_dims: + sum_mapping_dict.update({i: output_index}) + output_index += 1 + assert output_index == self.node._meta_data.dim() + + sum_info = (sum_dims, sum_mapping_dict) + physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "sum_info": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..855a2e7612af0cb59cae9bc8574197fad098f983 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py @@ -0,0 +1,32 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator +from .strategy.tensor_constructor_generator import TensorConstructorGenerator + +__all__ = ['TensorConstructorHandler'] + + +@operator_registry.register(torch.arange) +class TensorConstructorHandler(NodeHandler): + """ + A TensorConstructorHandler which deals with the sharding strategies for tensor constructor operations, such as torch.arange. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(TensorConstructorGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = {"output": physical_output_operand} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..bda1609065177da92a6f935aa434216a4cf9e94c --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py @@ -0,0 +1,43 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator, UnaryElementwiseGenerator + +__all__ = ['UnaryElementwiseHandler'] + + +@operator_registry.register(torch.Tensor.to) +@operator_registry.register(torch.Tensor.type) +@operator_registry.register(torch.abs) +@operator_registry.register(torch.nn.ReLU) +@operator_registry.register(torch.nn.Tanh) +@operator_registry.register(torch.tanh) +@operator_registry.register(torch.nn.modules.dropout.Dropout) +@operator_registry.register(torch.Tensor.contiguous) +@operator_registry.register(torch.nn.functional.dropout) +class UnaryElementwiseHandler(NodeHandler): + """ + A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(UnaryElementwiseGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "output": physical_output} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..6de2aaafdd018f08195563ef882f07eb39d8d20a --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py @@ -0,0 +1,71 @@ +import copy +import operator +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector +from ..utils import recover_sharding_spec_for_broadcast_shape +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator, WhereGenerator + +__all__ = ['WhereHandler'] + + +@operator_registry.register(torch.where) +class WhereHandler(NodeHandler): + """ + A WhereHandler which deals with the sharding strategies for torch.where. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + logical_op_data_mapping, _ = self.get_operation_data_mapping() + generators = [] + generators.append(WhereGenerator(logical_op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_condition_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + physical_x_operand = OperationData(name=str(self.node.args[1]), + type=OperationDataType.ARG, + data=self.node.args[1]._meta_data) + physical_y_operand = OperationData(name=str(self.node.args[2]), + type=OperationDataType.ARG, + data=self.node.args[2]._meta_data) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + physical_mapping = { + "condition": physical_condition_operand, + "x": physical_x_operand, + "y": physical_y_operand, + "output": physical_output + } + logical_shape_for_all = self.node._meta_data.shape + logical_mapping = {} + for key, physical_operand in physical_mapping.items(): + logical_mapping[key] = self.convert_physical_operand_to_logical_operand(physical_operand, + logical_shape_for_all) + + return logical_mapping, physical_mapping + + def convert_physical_operand_to_logical_operand(self, physical_operand, target_shape): + logical_operand = copy.deepcopy(physical_operand) + logical_operand.logical_shape = target_shape + return logical_operand + + def post_process(self, strategy: ShardingStrategy): + logical_op_data_mapping, physical_op_data_mapping = self.get_operation_data_mapping() + for key in logical_op_data_mapping.keys(): + logical_sharding_spec = strategy.sharding_specs[logical_op_data_mapping[key]] + logical_shape = logical_op_data_mapping[key].logical_shape + physical_shape = physical_op_data_mapping[key].logical_shape + physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( + logical_sharding_spec, logical_shape, physical_shape) + strategy.sharding_specs.pop(logical_op_data_mapping[key]) + strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec + strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}" + return strategy diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..4929e09ad531515c9070750ad2dc36674d42b043 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py @@ -0,0 +1,277 @@ +from copy import deepcopy +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Tuple, Union + +import torch +from torch.fx.node import Node + +from colossalai.tensor.shape_consistency import CommSpec +from colossalai.tensor.sharding_spec import ShardingSpec + +from .constants import ( + BCAST_FUNC_OP, + ELEMENTWISE_FUNC_OP, + ELEMENTWISE_METHOD_OP, + ELEMENTWISE_MODULE_OP, + RESHAPE_FUNC_OP, + RESHAPE_METHOD_OP, +) + +__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector'] + + +class OperationDataType(Enum): + """ + An operation can come from the argument list of an operator or the parameter list of a module. + """ + INPUT = 0 + ARG = 1 + PARAM = 2 + BUFFER = 3 + OUTPUT = 4 + + +@dataclass +class OperationData: + """ + OperationData is the data related to an operator, the data can be the operand or the output. + + Args: + name (str): the name of the operation-related data + type (OperationDataType): the type of the operation data + data (Any): the value for this data, usually it is a meta tensor. + logical_shape (Tuple[int]): the logical shape of the data, it can be different from the its actual shape in memory. + """ + name: str + type: OperationDataType + data: Any + logical_shape: Tuple[int] = None + + def __post_init__(self): + # if no logical shape is specified, use the data shape as the logical shape + if self.logical_shape is None: + + def _infer_logical_shape(data: any): + """ + This function is used to infer the logical shape of the data. + """ + if isinstance(data, torch.Tensor): + return data.shape + elif isinstance(data, torch.Size): + return None + elif isinstance(data, (tuple, list)): + data_type = type(data) + return data_type([_infer_logical_shape(d) for d in data]) + else: + return None + + self.logical_shape = _infer_logical_shape(self.data) + + def __repr__(self) -> str: + return f'OperationData(name={self.name}, type={self.type})' + + def __eq__(self, other) -> bool: + return other.name == self.name + + def __hash__(self) -> int: + return hash(f'{self.name}') + + +@dataclass +class TrainCycleItem: + """ + TrainCycleItem is a dataclass to store the items which have different values for the forward and backward pass + in a training iteration. + + Args: + fwd (float): the item for the forward pass + bwd (float): the item for the backward pass + """ + fwd: Any + bwd: Any + total: Any + + +@dataclass +class MemoryCost: + """ + MemoryCost is a dataclass which stores the memory usage in the program. + + Args: + activation (int): the memory cost incurred by the activations in bytes. + parameter (int): the memory cost incurred by the module parameter in bytes. + temp (int): the memory cost incurred by the temporary tensors in bytes. + buffer (int): the memory cost incurred by the module buffer in bytes. + """ + activation: int = 0 + parameter: int = 0 + temp: int = 0 + buffer: int = 0 + + +class CommType(Enum): + """ + CommType describes the sequential order of a communication action and a computation action. + + Meaning: + BEFORE: the communication action happens just before the computation operation. + AFTER: the communication action happens after the computation operation. + HOOK: the communication action is used to do the grad all reduce. + IMPLICIT: the communication action happens during the kernel execution, such as SyncBatchNorm + """ + BEFORE = 0 + AFTER = 1 + HOOK = 2 + IMPLICIT = 3 + + +@dataclass +class CommAction: + """ + CommAction is used to record the communication action. + + Args: + comm_spec: express the communication pattern and the process groups to execute the communication action. + comm_type: describes the sequential order of a communication action and a computation action. + arg_index: record the location of tensor which join the communication, we cannot use name of node or op_data at runtime, + because the args of node may be changed by graph transform passes. + """ + comm_spec: CommSpec = None + comm_type: CommType = None + arg_index: int = -1 + key_for_kwarg: any = None + + +@dataclass +class ShardingStrategy: + """ + ShardingStrategy is a dataclass to store the meta information on tensor sharding for a node. + + Args: + name (str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'. + output_sharding_spec (ShardingSpec): ShardingSpec of the output node. + compute_cost (TrainCycleItem): Computation cost to complete this strategy. (default to None) + communication_cost (TrainCycleItem): Communication cost to complete this strategy. (default to None) + memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None) + input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes. + """ + name: str + sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None + compute_cost: TrainCycleItem = None + communication_cost: TrainCycleItem = None + memory_cost: TrainCycleItem = None + communication_actions: Dict[OperationData, CommAction] = None + resharding_costs: Dict[Node, List[TrainCycleItem]] = None + + @property + def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]: + specs = {} + specs.update(self._get_sharding_spec(OperationDataType.ARG)) + specs.update(self._get_sharding_spec(OperationDataType.PARAM)) + return specs + + @property + def argument_sharding_specs(self) -> Dict[OperationData, ShardingSpec]: + return self._get_sharding_spec(OperationDataType.ARG) + + @property + def param_sharding_specs(self) -> Dict[OperationData, ShardingSpec]: + return self._get_sharding_spec(OperationDataType.PARAM) + + @property + def output_sharding_specs(self) -> Dict[OperationData, ShardingSpec]: + return self._get_sharding_spec(OperationDataType.OUTPUT) + + def _get_sharding_spec(self, operation_data_type: OperationDataType): + specs = {k: v for k, v in self.sharding_specs.items() if k.type == operation_data_type} + return specs + + def get_op_data_by_name(self, name: str): + for op_data in self.sharding_specs.keys(): + if op_data.name == name: + return op_data + raise KeyError(f"Could not find the OperationData with name {name}") + + def get_sharding_spec_by_name(self, name: str): + for op_data, sharding_spec in self.sharding_specs.items(): + if op_data.name == name: + return sharding_spec + raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}") + + def clone(self): + + def _deepcopy_dict_vals(data: Dict): + return {k: deepcopy(v) for k, v in data.items()} + + sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs is not None else None + # We need to deepcopy it when self.communication_actions is not None, instead of checking its __bool__ value. + # Consider the examples below: + # If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False. + # In this case, if we set None to the new object, program will crash when we try to access the communication_actions.items. + communication_actions = _deepcopy_dict_vals( + self.communication_actions) if self.communication_actions is not None else None + # same reason as communication_actions + resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None + compute_cost = deepcopy(self.compute_cost) + communication_cost = deepcopy(self.communication_cost) + memory_cost = deepcopy(self.memory_cost) + + return ShardingStrategy(name=self.name, + sharding_specs=sharding_specs, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + communication_actions=communication_actions, + resharding_costs=resharding_costs) + + +class StrategiesVector(list): + ''' + Each node in fx graph will have a corresponding StrategiesVector, to store all the possible + strategies of the node. + + Argument: + node (Node): node for which the list of sharding strategies are generated. + ''' + + def __init__(self, node: Node): + super().__init__() + self.node = node + # fetch its input and output nodes + # TODO: placeholder input nodes + self.predecessor_nodes = list(node._input_nodes.keys()) + self.successor_nodes = list(node.users.keys()) + + def check_merge(self): + merge_label = False + if self.node.op == 'call_module': + target = self.node.target + root_module = self.node.graph.owning_module + submod = root_module.get_submodule(target) + submod_type = type(submod) + # merge elementwise module node into source nodes + # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec. + if submod_type in ELEMENTWISE_MODULE_OP: + merge_label = True + + if self.node.op == 'call_function': + # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec. + if self.node.target in ELEMENTWISE_FUNC_OP: + merge_label = True + # we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case. + # TODO: remove this after we support the fall back logic. + # if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1: + # merge_label = True + # we could merge reshape op, because their computation costs are negligible. + if self.node.target in RESHAPE_FUNC_OP: + merge_label = True + + if self.node.op == 'call_method': + # we could merge reshape op, because their computation costs are negligible. + method = getattr(self.node.args[0]._meta_data.__class__, self.node.target) + if method in RESHAPE_METHOD_OP: + merge_label = True + if method in ELEMENTWISE_METHOD_OP: + merge_label = True + return merge_label diff --git a/colossalai/auto_parallel/tensor_shard/solver/__init__.py b/colossalai/auto_parallel/tensor_shard/solver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9f9ba8814a79a9fb052bedb1ce0e5899ce6a679 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/solver/__init__.py @@ -0,0 +1,7 @@ +from .cost_graph import CostGraph +from .graph_analysis import GraphAnalyser +from .options import SolverOptions +from .solver import Solver +from .strategies_constructor import StrategiesConstructor + +__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph', 'SolverOptions'] diff --git a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..038e56547b9664dec41170a5996e210090f15794 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py @@ -0,0 +1,208 @@ +import torch + +from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST + + +class CostGraph: + ''' + A graph data structure to simplify the edge cost graph. It has two main functions: + 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in + CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list. + 2. To reduce the searching space, we merge computationally-trivial operators, such as + element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will + be given by the StrategiesVector depending on the type of target node and following nodes. + + Argument: + leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph. + simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True) + ''' + + def __init__(self, leaf_strategies, simplify=True, forward_only=False): + self.leaf_strategies = leaf_strategies + self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies] + # stores number of strategies in each node + self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies} + # extra_node_costs will store the extra costs introduced by merging nodes + self.extra_node_costs = {} + self.following_dict = {} + self.simplify = simplify + self.forward_only = forward_only + self._build_cost_graph() + + def _remove_invalid_node(self, node, attr_name): + remove_list = [] + target_node_list = getattr(node, attr_name, []) + for target_node in target_node_list: + if target_node not in self.nodes: + remove_list.append(target_node) + for element in remove_list: + target_node_list.remove(element) + + def _build_cost_graph(self): + ''' + This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be + set to node. + ''' + self.edge_costs = {} + if self.simplify: + self.merge_pair = [] + for strategies_vector in self.leaf_strategies: + # build edge_cost + dst_node = strategies_vector.node + for src_node in strategies_vector.predecessor_nodes: + if src_node not in self.nodes: + continue + node_pair = (src_node, dst_node) + edge_cost = {} + for i in range(len(strategies_vector)): + for j in range(len(src_node.strategies_vector)): + resharding_cost_item = strategies_vector[i].resharding_costs[src_node][j] + if self.forward_only: + edge_cost[(j, i)] = resharding_cost_item.fwd + else: + edge_cost[(j, i)] = resharding_cost_item.total + self.edge_costs[node_pair] = edge_cost + # add parents and children attribute to node + # parent_nodes = [node for node in strategies_vector.predecessor_nodes] + # children_nodes = [node for node in strategies_vector.successor_nodes] + parent_nodes = [] + children_nodes = [] + + def _check_tensor_in_node(data): + """ + This method is used to check whether the data has a tensor inside or not. + """ + has_tensor_flag = False + if isinstance(data, torch.Tensor): + return True + elif isinstance(data, (tuple, list)): + for d in data: + has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d) + return has_tensor_flag + + for node in strategies_vector.predecessor_nodes: + if _check_tensor_in_node(node._meta_data): + parent_nodes.append(node) + for node in strategies_vector.successor_nodes: + if _check_tensor_in_node(node._meta_data): + children_nodes.append(node) + + setattr(dst_node, 'parents', parent_nodes) + setattr(dst_node, 'children', children_nodes) + + if self.simplify and strategies_vector.check_merge(): + for followed_node in strategies_vector.predecessor_nodes: + # we only merge node pairs which src node has a tensor element inside. + # This is necessay because the node without a tensor element inside will not + # be assigned any strategy. + if _check_tensor_in_node(followed_node._meta_data): + self.merge_pair.append((followed_node, dst_node)) + + def get_edge_cost(self, src_node, dst_node): + return self.edge_costs[(src_node, dst_node)] + + def merge_node(self, src_node, dst_node): + ''' + To merge dst_node into src_node, we need to do it in following steps: + + 1. For each strategy in dst_node, we need to pick an appropriate strategy + of src_node to merge, it is important because the logical resharding costs + between the parents node of src_node and merged node depend on the src_node + strategies dispatching. For example, for the graph 0->1->2, after merging node 1 + into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)] + x represents the picking strategy of node 1 merged into node 2 strategy 0. + + 2. We need to accumulate the extra costs introduced by merging nodes, the extra costs + contains two parts, one is resharding costs between src_node strategy and dst_node strategy, + another is the origin extra costs in src_node strategy. + + 3. Build connections between new node pairs, and remove the src_node after all consumer nodes + detached from it. + + Argument: + src_node(Node): The node will be merged into dst_node. + dst_node(Node): The node to integrate src_node. + ''' + # build merge_map + merge_map = {} + for src_index, _ in enumerate(src_node.strategies_vector): + min_cost = INFINITY_COST + lowest_cost_index = -1 + for dst_index, dst_strategy in enumerate(dst_node.strategies_vector): + resharding_cost_item = dst_strategy.resharding_costs[src_node][src_index] + if self.forward_only: + resharding_cost = resharding_cost_item.fwd + else: + resharding_cost = resharding_cost_item.total + if resharding_cost <= min_cost: + min_cost = resharding_cost + lowest_cost_index = dst_index + merge_map[src_index] = lowest_cost_index + + # extra_node_cost for src node + self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node] + for src_index, strategy in enumerate(src_node.strategies_vector): + target_strate_index = merge_map[src_index] + target_strategy = dst_node.strategies_vector[target_strate_index] + resharding_cost_item = target_strategy.resharding_costs[src_node][src_index] + if self.forward_only: + resharding_cost_to_add = resharding_cost_item.fwd + else: + resharding_cost_to_add = resharding_cost_item.total + self.extra_node_costs[src_node][src_index] += resharding_cost_to_add + if dst_node in self.extra_node_costs: + self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index] + + # add new node pair to cost graph + for child_node in dst_node.children: + new_node_pair = (src_node, child_node) + old_node_pair = (dst_node, child_node) + if new_node_pair in self.edge_costs: + continue + edge_cost = {} + for i in range(self.node_lens[src_node]): + for j in range(self.node_lens[child_node]): + dst_strate_index = merge_map[i] + edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)] + if new_node_pair not in self.edge_costs: + self.edge_costs[new_node_pair] = edge_cost + else: + # we should accumulate the resharding costs if args of child node contain + # both src node and dst node. + for index_pair, resharding_cost in self.edge_costs[new_node_pair]: + self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair] + + # connect src node and children of dst node + dst_node.parents.remove(src_node) + src_node.children.remove(dst_node) + self.edge_costs.pop((src_node, dst_node)) + for child_node in dst_node.children: + if child_node not in src_node.children: + src_node.children.append(child_node) + if src_node not in child_node.parents: + child_node.parents.append(src_node) + # remove dst node from cost graph when dst node has no producer. + if len(dst_node.parents) == 0: + child_node.parents.remove(dst_node) + node_pair = (dst_node, child_node) + self.edge_costs.pop(node_pair) + if len(dst_node.parents) == 0: + self.following_dict[dst_node] = src_node + dst_node.children = [] + + def _reindexing_src(self, src): + if src not in self.following_dict: + return src + return self._reindexing_src(self.following_dict[src]) + + def simplify_graph(self): + if not self.simplify: + return + self.merge_pair.reverse() + for (src_node, dst_node) in self.merge_pair: + self.merge_node(src_node, dst_node) + self.merge_pair.reverse() + reindexing_following_dict = {} + for dst, src in self.following_dict.items(): + reindexing_following_dict[dst] = self._reindexing_src(src) + self.following_dict = reindexing_following_dict diff --git a/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..be39a74cb23755f9ff2b83cf1123a7a7f9708ffa --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py @@ -0,0 +1,164 @@ +from dataclasses import dataclass +from typing import List + +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node + +from colossalai.fx.passes.utils import get_node_module + +__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser'] + + +@dataclass +class LiveVariable: + """ + LiveVariable is a data structure to store the meta information of a variable for liveness analysis. + """ + name: str + node: Node + is_inplace: bool + + +class LiveVariableVector(list): + """ + LiveVariableVector is a data structure to store the list of LiveVariable objects. + """ + + def exists(self, name) -> bool: + """ + Check if a variable has already existed in the current list by name. + """ + for var in self: + if name == var.name: + return True + return False + + def get(self, name) -> LiveVariable: + for var in self: + if name == var.name: + return var + raise KeyError(f"Variable {name} is not found") + + def copy(self) -> "LiveVariableVector": + """ + Create a copy of this vector + """ + vector = LiveVariableVector() + for var in self: + vector.append(var) + return vector + + +@dataclass +class LiveStage: + """ + LiveStage is a data structure to record the living variables at this current node. + """ + name: str + node: Node + all_live_vars: LiveVariableVector + unique_live_vars: LiveVariableVector + + +class GraphAnalyser: + + def __init__(self, gm: GraphModule): + self._gm = gm + self._graph = gm.graph + + @property + def gm(self) -> GraphModule: + """ + Return the GraphModule object associated with this analyser. + """ + return self._gm + + @property + def graph(self) -> Graph: + """ + Return the Graph object associated with this analyser. + """ + return self._graph + + def liveness_analysis(self) -> List[LiveStage]: + """ + Analyse the graph to obtain the variable liveness information. This function returns + an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object. + """ + compute_nodes = self.graph.nodes + liveness_list = [] + + # checked: record all variables created since the first stage + # all: record the live variables only exist until the current stage. + # this can be different from the `checked list`` as some varialbes may be destroyed prior to this stage. + # unique: record the unique live variables only exist until the current stage. + # this is different from `all list` as some variables are duplicated. + checked_variables = LiveVariableVector() + all_live_variables = LiveVariableVector() + unique_live_vars = LiveVariableVector() + + for idx, node in enumerate(compute_nodes): + ############################# + # find new living variables # + ############################# + # detect whether the current op is an in-place op + # if it is an in-place op, we would deem it as a duplciate var + is_inplace = False + if node.op == 'call_function': + # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True) + if node.kwargs.get('inplace', False): + is_inplace = True + elif node.op == 'call_module': + # to check if this is an inplace op such as torch.nn.Relu(inplace=True) + module = get_node_module(node) + if getattr(module, 'inplace', False): + is_inplace = True + + # add the output var + meta = getattr(node, '_meta_data', None) + live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace) + if not is_inplace: + unique_live_vars.append(live_var) + checked_variables.append(live_var) + all_live_variables.append(live_var) + + # check if any input is not checked yet + for arg in node.args: + if not isinstance(arg, Node): + continue + arg_name = arg.name + if not checked_variables.exists(arg_name): + live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False) + all_live_variables.append(live_var_from_arg) + checked_variables.append(live_var_from_arg) + unique_live_vars.append(live_var_from_arg) + + # TODO: add the logic to remove live variables + # this should be completed if we are able to trace the backward compute graph + + # add this stage to liveness dict + stage = LiveStage(name=node.name, + node=node, + all_live_vars=all_live_variables.copy(), + unique_live_vars=unique_live_vars.copy()) + # if a LiveStage is covered by another LiveStage, we just keep the larger one. + replace = False + for index, prev_stage in enumerate(liveness_list): + all_covered = True + for ele in prev_stage.unique_live_vars: + if ele not in stage.unique_live_vars: + all_covered = False + break + if all_covered: + replace = True + break + if replace: + liveness_list[index] = stage + else: + liveness_list.append(stage) + + return liveness_list + + def get_alias_set(self): + pass diff --git a/colossalai/auto_parallel/tensor_shard/solver/options.py b/colossalai/auto_parallel/tensor_shard/solver/options.py new file mode 100644 index 0000000000000000000000000000000000000000..b52e55708dfde3a43e06ee43a42a8e47776c4d31 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/solver/options.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass +from enum import Enum + +__all__ = ['SolverOptions'] + + +class SolverPerference(Enum): + """ + This enum class is to define the solver preference. + """ + STANDARD = 0 + DP = 1 + TP = 2 + + +class DataloaderOption(Enum): + """ + This enum class is to define the dataloader option. + """ + REPLICATED = 0 + DISTRIBUTED = 1 + + +@dataclass +class SolverOptions: + """ + SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search. + """ + solver_perference: SolverPerference = SolverPerference.STANDARD + dataloader_option: DataloaderOption = DataloaderOption.REPLICATED diff --git a/colossalai/auto_parallel/tensor_shard/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py new file mode 100644 index 0000000000000000000000000000000000000000..89d0da2235a27f8554c0e33b0f0861bc357d46d2 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py @@ -0,0 +1,486 @@ +import multiprocessing +import time +import warnings +from typing import Dict + +import numpy as np +from torch.fx.graph import Graph +from torch.fx.node import Node + +from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST + +from .cost_graph import CostGraph +from .graph_analysis import GraphAnalyser +from .strategies_constructor import StrategiesConstructor + +try: + import pulp + from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum +except: + warnings.warn(f'please install the pulp') + +__all___ = ['Solver'] + + +class Solver: + + def __init__(self, + graph: Graph, + strategies_constructor: StrategiesConstructor, + cost_graph: CostGraph, + graph_analyser: GraphAnalyser, + memory_budget: float = -1.0, + solution_numbers: int = 1, + forward_only: bool = False, + memory_increasing_coefficient: float = 1.3, + verbose=True): + ''' + Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph. + Argument: + graph: The computing graph to be optimized. + strategies_constructor: It will provide all the possible strategies for each node in the computing graph. + cost_graph: A graph data structure to simplify the edge cost graph. + graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints. + memory_budget: Memory constraint for the solution. + solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget. + memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget. + ''' + self.graph = graph + self.strategies_constructor = strategies_constructor + self.cost_graph = cost_graph + self.graph_analyser = graph_analyser + self.leaf_strategies = self.strategies_constructor.leaf_strategies + self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies] + self.strategy_map = self.strategies_constructor.strategy_map + self.memory_budget = memory_budget + self.solution_numbers = solution_numbers + self.forward_only = forward_only + if self.solution_numbers > 1: + self.memory_increasing_coefficient = memory_increasing_coefficient + else: + self.memory_increasing_coefficient = 1 + self.liveness_list = self.graph_analyser.liveness_analysis() + self.node_index_dict = self._generate_node_index_dict() + # The last solution vector of auto sharding. + self.last_s_val = None + # The last objective value of the best ILP solution. + self.last_objective = None + self.verbose = verbose + + def _recover_merged_node_strategy(self): + ''' + During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node. + Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged + node. + ''' + for node_index, node in enumerate(self.nodes): + if node.strategies_vector.check_merge(): + # the merged node has only one input, and its strategies follow the input sharding strategy + input_strategies_vector = node.args[0].strategies_vector + input_best_strategy_index = self.last_s_val[node_index - 1] + input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec + for strategy_index, strategy in enumerate(node.strategies_vector): + if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence: + self.last_s_val[node_index] = strategy_index + break + + def _generate_node_index_dict(self) -> Dict[Node, int]: + node_index_dict = {} + for index, strategies_vector in enumerate(self.leaf_strategies): + node_index_dict[strategies_vector.node] = index + return node_index_dict + + def _prepare_data_for_solver(self): + ''' + Extract information from components for solver. + ''' + node_nums = len(self.leaf_strategies) + memory_budget = self.memory_budget + + # prepare strategies_len + strategies_len = [] + for node in self.nodes: + strategies_len.append(self.cost_graph.node_lens[node]) + strategies_len = np.array(strategies_len) + + # prepare following_nodes + following_nodes = self.cost_graph.following_dict + index_following_nodes = {} + for src, target in following_nodes.items(): + src_index = self.node_index_dict[src] + target_index = self.node_index_dict[target] + index_following_nodes[src_index] = target_index + following_nodes = index_following_nodes + for index in range(node_nums): + if index not in following_nodes: + following_nodes[index] = -1 + + # prepare edge_pairs and resharding costs + edge_pairs = [] + resharding_costs = [] + for pairs, edge_cost in self.cost_graph.edge_costs.items(): + src_node = pairs[0] + dst_node = pairs[1] + src_node_index = self.node_index_dict[src_node] + dst_node_index = self.node_index_dict[dst_node] + edge_pairs.append(src_node_index) + edge_pairs.append(dst_node_index) + + for i in range(strategies_len[src_node_index]): + for j in range(strategies_len[dst_node_index]): + resharding_costs.append(edge_cost[(i, j)]) + edge_pairs = np.array(edge_pairs) + resharding_costs = np.array(resharding_costs) + + # prepare liveness_set + liveness_set = self.liveness_list + + # omit alias_set now + alias_set = None + alias_convert_costs = None + + # prepare compute_costs, communication_costs and memory_costs + compute_costs = [] + communication_costs = [] + memory_costs = [] + extra_node_costs = self.cost_graph.extra_node_costs + for strategies_vector in self.leaf_strategies: + node = strategies_vector.node + for index, strategy in enumerate(strategies_vector): + compute_cost_item = strategy.compute_cost + communication_cost_item = strategy.communication_cost + memory_cost_item = strategy.memory_cost + + if self.forward_only: + origin_communication_cost = communication_cost_item.fwd + compute_cost = compute_cost_item.fwd + # extract MemoryCost item from the memory TrainCycleItem + memory_cost = memory_cost_item.fwd + else: + origin_communication_cost = communication_cost_item.total + compute_cost = compute_cost_item.total + # extract MemoryCost item from the memory TrainCycleItem + memory_cost = memory_cost_item.total + + # extract the memory cost in float from MemoryCost item and sum them up + memory_cost = memory_cost.parameter + memory_cost.activation + memory_cost.buffer + compute_costs.append(compute_cost) + # node in extra_node_costs means it has some extra communication + # cost from node merging, so we need to add those extra communication + # cost into + if node in extra_node_costs: + extra_node_cost = extra_node_costs[node][index] + communication_cost = origin_communication_cost + extra_node_cost + communication_costs.append(communication_cost) + else: + communication_costs.append(origin_communication_cost) + memory_costs.append(memory_cost) + + compute_costs = np.array(compute_costs) + communication_costs = np.array(communication_costs) + memory_costs = np.array(memory_costs) + + # omit initial value for nodes + s_init_np = None + + return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose + + def _call_solver_serialized_args(self, + node_nums, + memory_budget, + strategies_len, + following_nodes, + edge_pairs, + alias_set, + liveness_set, + compute_costs, + communication_costs, + memory_costs, + resharding_costs, + alias_convert_costs, + s_init_np=None, + verbose=True): + """ + Call the solver with serialized arguments. + """ + + tic = time.time() + + for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]: + assert isinstance(x, np.ndarray) + assert len(strategies_len) == node_nums, "strategies_len" + + def get_non_zero_index(binary_vector): + """ + Get the index of non-zero item in a vector. + """ + ct = 0 + ret = None + for i, elem in enumerate(binary_vector): + if pulp.value(elem): + ret = i + ct += 1 + + assert ct == 1 + return ret + + # 0. Unpack flatten numpy arrays + s_follow = following_nodes + + E = edge_pairs.reshape((-1, 2)) # noqa + r = [] + pt = 0 + edge_set = set() + for (i, j) in E: + prod_length = strategies_len[i] * strategies_len[j] + + if (i, j) in edge_set: + raise ValueError(f"Duplicated edges: {(i, j)}") + + edge_set.add((i, j)) + r.append(resharding_costs[pt:pt + prod_length]) + pt += prod_length + assert pt == len(resharding_costs) + + ###################### + # omit alias set now # + ###################### + + # A = alias_set.reshape((-1, 2)) # noqa + # for (i, j) in A: + # prod_length = strategies_len[i] * strategies_len[j] + # v.append(alias_convert_costs[pt:pt + prod_length]) + # pt += prod_length + # assert pt == len(alias_convert_costs) + + # L = [] # noqa + # pt = node_nums + # for i in range(node_nums): + # length = liveness_set[i] + # L.append(liveness_set[pt:pt + length]) + # pt += length + # assert pt == len(liveness_set) + v = [] + pt = 0 + + c = [] + d = [] + m = [] + pt = 0 + for i in range(node_nums): + length = strategies_len[i] + c.append(compute_costs[pt:pt + length]) + d.append(communication_costs[pt:pt + length]) + m.append(memory_costs[pt:pt + length]) + pt += length + assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}" + assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}" + assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}" + + # 1. Create variables + + ############################# + # create variables for node # + ############################# + s = [] + num_nodes = 0 + reverse_follow_backpatch = [] + for i in range(node_nums): + if s_follow[i] < 0: + if strategies_len[i] == 1: + s.append([1]) + else: + num_nodes += 1 + s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary")) + else: + if s_follow[i] < len(s): + s.append(s[s_follow[i]]) + else: + s.append(None) + reverse_follow_backpatch.append(i) + + for i in reverse_follow_backpatch: + s[i] = s[s_follow[i]] + + ############################# + # create variables for edge # + ############################# + e = [] + num_edges = 0 + for (idx, (i, j)) in enumerate(E): + if len(s[i]) == 1: + e.append(s[j]) + elif len(s[j]) == 1: + e.append(s[i]) + else: + num_edges += 1 + e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary")) + assert len(e[idx]) == len(r[idx]) + for element in s: + assert len(element) > 0 + # 2. Set initial value + ###################################### + # set a initial value for warm start # + ###################################### + if s_init_np is not None: + s_init = s_init_np.reshape((-1, 3)) + for (idx, value, fix) in s_init: + for i in range(len(s[idx])): + s[idx][i].setInitialValue(i == value) + if fix: + s[idx][i].fixValue() + + # 3. Objective + prob = LpProblem("myProblem", LpMinimize) + ################################################################### + # computing the node cost(computing cost and communication cost) # + ################################################################### + obj = 0 + for i in range(node_nums): + assert len(s[i]) == len(c[i]) + assert len(s[i]) == len(d[i]) + + obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i]) + + ############################################# + # computing the edge cost(resharding cost) # + ############################################# + for i in range(len(E)): + assert len(e[i]) == len(r[i]) + obj += lpDot(e[i], r[i]) + + prob += obj + + # 4. Constraints + # (a). specified by `cat="Binary"` + + # (b) + ################################################# + # make sure each node only choose one strategy # + ################################################# + for i in range(node_nums): + if s_follow[i] < 0: + prob += lpSum(s[i]) == 1 + + # (c) + ################################################# + # compute memory consumption with liveness set # + ################################################# + if memory_budget > 0: + for liveness_stage in liveness_set: + mem = 0 + for live_variable in liveness_stage.unique_live_vars: + if live_variable.node not in self.node_index_dict: + continue + node_index = self.node_index_dict[live_variable.node] + mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index]))) + prob += mem <= memory_budget + + # (d). specified by `cat="Binary"` + + for (idx, (i, j)) in enumerate(E): + if strategies_len[i] == 1 or strategies_len[j] == 1: + continue + + # (e) + prob += lpSum(e[idx]) == 1 + + # (f) + for row in range(len(s[i])): + C = len(s[j]) # noqa + prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row] + + # (g) + for col in range(len(s[j])): + R = len(s[i]) # noqa + C = len(s[j]) # noqa + prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col] + + # (h) + ###################### + # omit alias set now # + ###################### + + # alias_set = set() + # for (idx, (i, j)) in enumerate(A): + # R = len(s[i]) # noqa + # C = len(s[j]) # noqa + # if (i, j) in alias_set: + # raise ValueError(f"Duplicated edges: {(i, j)}") + + # alias_set.add((i, j)) + # alias_set.add((j, i)) + + # for row in range(len(s[i])): + # for col in range(len(s[j])): + # if v[idx][row * C + col] > 0.5: + # prob += s[i][row] + s[j][col] <= 1 + + msg = verbose + time_limit = 600 + assert "COIN_CMD" in pulp.listSolvers( + onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'") + + solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count()) + # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit) + prob.solve(solver) + + status = prob.status + objective = pulp.value(prob.objective) + objective = float(objective) if objective is not None else -1.0 + if verbose: + print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" + f"Time: {time.time() - tic}") + print(f"#nodes: {num_nodes}, #edges: {num_edges}") + + if prob.status in [pulp.LpStatusInfeasible]: + raise RuntimeError("Cannot run the function under the given memory budget. " + "Please increase the memory budget.") + + # Get and check results + s_val = np.full((node_nums,), -1, dtype=np.int32) + for i in range(node_nums): + s_val[i] = get_non_zero_index(s[i]) + + e_val = np.full((len(E),), -1, dtype=np.int32) + for (idx, (i, j)) in enumerate(E): + e_val[idx] = get_non_zero_index(e[idx]) + i_spec_index = e_val[idx] // len(s[j]) + j_spec_index = e_val[idx] % len(s[j]) + assert i_spec_index == s_val[i], f"e_val[{i}][{j}]" + assert j_spec_index == s_val[j], f"e_val[{i}][{j}]" + if verbose and r[idx][e_val[idx]] > 0: + print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}") + + self.last_s_val = list(s_val) + # self._recover_merged_node_strategy() + self.last_objective = objective + + if objective > INFINITY_COST: + warnings.warn("Detect unexpected behaviors in the auto-sharding pass.") + + return self.last_s_val, e_val, self.last_objective, status + + def call_solver_serialized_args(self): + """ + Call the solver with serialized arguments and handle python errors. Additionally, + we could give a serious of solutions with different memory budget. + """ + if self.solution_numbers == 1: + args = self._prepare_data_for_solver() + ret = self._call_solver_serialized_args(*args) + + return ret + + origin_memory_budget = self.memory_budget + memory_budget_list = [ + origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers) + ] + ret_list = [] + for memory_budget in memory_budget_list: + self.memory_budget = memory_budget + args = self._prepare_data_for_solver() + ret = self._call_solver_serialized_args(*args) + ret_list.append(ret) + + return ret_list diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..9d1ff7fd13496e5e90534350a98cf282140a5587 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py @@ -0,0 +1,152 @@ +import builtins +import math +import operator +from copy import deepcopy +from typing import Dict, List + +import torch +from torch.fx import Graph, Node + +from colossalai.auto_parallel.tensor_shard.node_handler import ( + GetattrHandler, + OuputHandler, + PlacehodlerHandler, + operator_registry, +) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector +from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec +from colossalai.device.device_mesh import DeviceMesh + +from .options import DataloaderOption, SolverOptions + +__all__ = ['StrategiesConstructor'] + + +class StrategiesConstructor: + """ + StrategiesConstructor is used to construct the parallelization plan for the model execution. + + Args: + graph (Graph): a Graph object used for analysis and strategy generation. + device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster. + solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching. + """ + + def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions): + self.graph = graph + assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' + self.root_module = self.graph.owning_module + self.nodes = list(graph.nodes) + self.device_mesh = device_mesh + self.leaf_strategies = [] + self.strategy_map = {} + self.solver_options = solver_options + self.no_strategy_nodes = [] + + def remove_duplicated_strategy(self, strategies_vector): + ''' + In build_strategies_and_cost method, we may produce some duplicated strategies. + In this method, we will remove the duplicated strategies depending on the strategies name. + Note that this operation is in-place. + ''' + name_checklist = [] + remove_list = [] + for strategy in strategies_vector: + if strategy.name not in name_checklist: + name_checklist.append(strategy.name) + else: + remove_list.append(strategy) + for strategy in remove_list: + strategies_vector.remove(strategy) + + def build_strategies_and_cost(self): + """ + This method is to build the strategy vector for each node in the computation graph. + """ + + def _check_no_strategy_for_node(node): + if node.op in ('placeholder', 'get_attr', 'output'): + return False + + def _check_no_strategy_for_data(data): + label = True + if isinstance(data, torch.Tensor): + return False + elif isinstance(data, (tuple, list)): + for d in data: + label = label and _check_no_strategy_for_data(d) + return label + + return _check_no_strategy_for_data(node._meta_data) + + for node in self.nodes: + strategies_vector = StrategiesVector(node) + + if _check_no_strategy_for_node(node): + self.no_strategy_nodes.append(node) + pass + + # placeholder node + elif node.op == 'placeholder': + if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: + placeholder_option = 'distributed' + else: + assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' + placeholder_option = 'replicated' + placeholder_handler = PlacehodlerHandler(node, + self.device_mesh, + strategies_vector, + placeholder_option=placeholder_option) + placeholder_handler.register_strategy() + + # get_attr node + elif node.op == 'get_attr': + getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector) + getattr_handler.register_strategy() + + # call_module node + elif node.op == 'call_module': + target = node.target + submod = self.root_module.get_submodule(target) + submod_type = type(submod) + handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector) + handler.register_strategy() + + # call_function node + elif node.op == 'call_function': + target = node.target + handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector) + handler.register_strategy() + + # call_method node + elif node.op == 'call_method': + method = getattr(node.args[0]._meta_data.__class__, node.target) + handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector) + handler.register_strategy() + + # output node + elif node.op == 'output': + if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: + output_option = 'distributed' + else: + assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' + output_option = 'replicated' + output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option) + output_handler.register_strategy() + + self.remove_duplicated_strategy(strategies_vector) + setattr(node, 'strategies_vector', strategies_vector) + self.leaf_strategies.append(strategies_vector) + self.strategy_map[node] = strategies_vector + + # remove no strategy nodes + remove_list = [] + for strategies_vector in self.leaf_strategies: + if len(strategies_vector) == 0: + remove_list.append(strategies_vector.node) + + for node in remove_list: + if node.strategies_vector in self.leaf_strategies: + self.leaf_strategies.remove(node.strategies_vector) + if node in self.strategy_map: + self.strategy_map.pop(node) diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7fe5430bf136b08d93706b33cf4dbf82e342013 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py @@ -0,0 +1,25 @@ +from .broadcast import ( + BroadcastType, + comm_actions_for_oprands, + get_broadcast_shape, + is_broadcastable, + recover_sharding_spec_for_broadcast_shape, +) +from .factory import generate_resharding_costs, generate_sharding_spec +from .misc import check_sharding_spec_validity, ignore_sharding_exception, pytree_map +from .reshape import check_keep_sharding_status, detect_reshape_mapping, infer_output_dim_partition_dict +from .sharding import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + generate_sharding_size, + transpose_partition_dim, + update_partition_dim, +) + +__all__ = [ + 'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape', + 'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity' + 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', + 'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map', + 'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict' +] diff --git a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py new file mode 100644 index 0000000000000000000000000000000000000000..28aa551328d7a6d5f283338fe55c90eb102d253c --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py @@ -0,0 +1,160 @@ +from enum import Enum, auto +from typing import List + +import torch +from torch.fx.node import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + OperationData, + OperationDataType, +) +from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec +from colossalai.tensor.sharding_spec import ShardingSpec + +__all__ = [ + 'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape', + 'comm_actions_for_oprands' +] + + +class BroadcastType(Enum): + EQUAL = auto() + PADDDING = auto() + MULTIPLE = auto() + + +def is_broadcastable(shape1: torch.Size, shape2: torch.Size) -> bool: + """ + Check if two shapes are broadcastable to each other. + """ + for s1, s2 in zip(shape1[::-1], shape2[::-1]): + if s1 == 1 or s2 == 1 or s1 == s2: + pass + else: + return False + return True + + +def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]: + """ + Compute the broadcast shape given two shapes. + """ + assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable' + shape1_reverse = shape1[::-1] + shape2_reverse = shape2[::-1] + min_common_dim = min(len(shape1), len(shape2)) + dims = [] + for s1, s2 in zip(shape1_reverse, shape2_reverse): + dims.append(max(s1, s2)) + + # append the remaining dims + dims.extend(shape1_reverse[min_common_dim:]) + dims.extend(shape2_reverse[min_common_dim:]) + return dims[::-1] + + +def get_broadcast_dim_info(logical_shape, physical_shape): + # get the number of dimensions + logical_num_dims = len(logical_shape) + physical_num_dims = len(physical_shape) + + assert logical_num_dims >= physical_num_dims, \ + 'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!' + + # track the dim and its broadcasting type + logical_dim_broadcast_info = {} + + for i in range(logical_num_dims): + # get the trailing dim size + logical_dim_idx = logical_num_dims - i - 1 + phyiscal_dim_idx = physical_num_dims - i - 1 + logical_dim_size = logical_shape[logical_dim_idx] + + if phyiscal_dim_idx >= 0: + physical_dim_size = physical_shape[phyiscal_dim_idx] + + if physical_dim_size == logical_dim_size: + logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.EQUAL + elif physical_dim_size == 1 and physical_dim_size != logical_dim_size: + logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.MULTIPLE + else: + logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING + + return logical_dim_broadcast_info + + +def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, + physical_shape: torch.Size) -> ShardingSpec: + """ + This function computes the sharding spec for the physical shape of a broadcast tensor. + + Args: + logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor + logical_shape (torch.Size): logical shape is the broadcast shape of a tensor + physical_shape (torch.Size): the shape of the tensor before broadcasting + """ + # if the two shapes are the same, no broadcast occurs + # we directly return the current sharding spec + + # recording the sharding dimensions removed during logical shape converting to physical one + removed_dims = [] + if list(logical_shape) == list(physical_shape): + return logical_sharding_spec, removed_dims + + # get the number of dimensions + logical_num_dims = len(logical_shape) + physical_num_dims = len(physical_shape) + + # get the broadcast info + logical_dim_broadcast_info = get_broadcast_dim_info(logical_shape, physical_shape) + + # generate the sharding spec for the physical shape + physical_dim_partition = {} + logical_dim_partition = logical_sharding_spec.dim_partition_dict + + for shape_dim, mesh_dim in logical_dim_partition.items(): + logical_broadcast_type = logical_dim_broadcast_info[shape_dim] + + if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE: + removed_dims.extend(mesh_dim) + else: + # get the corresponding physical dim + physical_dim = physical_num_dims - (logical_num_dims - shape_dim) + physical_dim_partition[physical_dim] = mesh_dim + + physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh, + entire_shape=physical_shape, + dim_partition_dict=physical_dim_partition) + + return physical_sharding_spec, removed_dims + + +def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData, + sharding_spec: ShardingSpec) -> CommAction: + """ + This method is used to generate communication actions for oprands which lose information + during convert logical shape to physical shape. + """ + if len(removed_dims) == 1: + # if list length is 1, extract element from list to avoid using flatten device mesh + removed_dims = removed_dims[0] + comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + sharding_spec=sharding_spec, + logical_process_axis=removed_dims) + if op_data.type == OperationDataType.PARAM: + comm_type = CommType.HOOK + else: + comm_type = CommType.BEFORE + arg_index = -1 + for index, arg in enumerate(node.args): + if op_data.name == str(arg): + arg_index = index + assert arg_index >= 0, f'op_data should be an argument of node.' + comm_action = CommAction( + comm_spec=comm_spec, + comm_type=comm_type, + arg_index=arg_index, + ) + return comm_action diff --git a/colossalai/auto_parallel/tensor_shard/utils/factory.py b/colossalai/auto_parallel/tensor_shard/utils/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..fd3ba3d41c30ee685635b17d02367c7c8ae44d69 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/factory.py @@ -0,0 +1,90 @@ +import operator +import warnings +from functools import reduce +from typing import Dict, List, Optional, Union + +import torch +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec +from torch.fx.node import Node + +from ..constants import INFINITY_COST + +__all__ = ['generate_sharding_spec', 'generate_resharding_costs'] + + +def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, + dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: + """ + Generate the sharding spec of the tensor based on the given dim_partition_dict. + + + Args: + input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node. + device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster. + dim_partition_dict (Dict[int, List[int]]): a dictionary to specify the sharding specs, the key is the tensor dimension and the value is the mesh dimension for sharding. + """ + + if isinstance(input_, Node): + assert hasattr(input_, '_meta_data'), f'The given node has no attribte _meta_data' + meta_tensor = input_._meta_data + assert meta_tensor is not None, "The given node's _meta_data attribute is None" + shape = meta_tensor.shape + elif isinstance(input_, torch.Tensor): + shape = input_.shape + else: + raise TypeError( + f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.' + ) + for dim_index, sharding_index_list in dim_partition_dict.items(): + sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list] + sharding_size = reduce(operator.mul, sharding_list, 1) + assert shape[ + dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.' + + sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict) + return sharding_spec + + +def generate_resharding_costs(nodes: List[Node], + sharding_specs: List[ShardingSpec], + count_backward: Optional[bool] = True, + dtype: Optional[torch.dtype] = None, + index=None): + ''' + Compute the resharding costs with this specific strategy. + + Argument: + nodes (List[Node]): a list of nodes + sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes. + count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference. + dtype (Optional[torch.dtype]): the data type for cost calculation, default is None. + ''' + # The resharding_cost of weight is counted due to sharing weight cases. + resharding_costs = {} + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + + # shape consistency manager is a singleton class + shape_consistency_manager = ShapeConsistencyManager() + + for input_node, input_spec in zip(nodes, sharding_specs): + resharding_costs[input_node] = [] + for strategy in input_node.strategies_vector: + input_sharding_spec = strategy.output_sharding_spec + if not isinstance(input_sharding_spec, ShardingSpec): + assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.' + input_sharding_spec = input_sharding_spec[index] + assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' + try: + # compute the resharding cost + _, _, total_resharding_cost = shape_consistency_manager.shape_consistency( + input_sharding_spec, input_spec) + + # we need multiply the size of elem dtype to get correct communication cost + resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes + except AssertionError as e: + warnings.warn(f'{e}') + resharding_cost = INFINITY_COST + resharding_costs[input_node].append(resharding_cost) + return resharding_costs diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..9e402dab757820c5d76ee6d1166de473c040784b --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py @@ -0,0 +1,97 @@ +import functools +from typing import Any, Callable, Dict, List, Tuple, Type, Union + +import torch + +from colossalai.logging import get_dist_logger +from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException + +__all__ = ['ignore_sharding_exception', 'pytree_map'] + + +def ignore_sharding_exception(func): + """ + A function wrapper to handle the ShardingSpecException in the function. + If ShardingSpecException occurs, this function will return None. + + Usage: + # mute the assertion error in the function + @ignore_sharding_exception + def do_something(): + ... + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + logger = get_dist_logger() + rst = func(*args, **kwargs) + return rst + except ShardingSpecException as e: + logger.debug(e) + return None + + return wrapper + + +def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor): + """ + This function checks whether the ShardingSpec is valid for the physical tensor. + This check includes 3 items: + 1. the sharding spec covers all dimensions of the physical tensor + 2. the sharding spec for each dimension is divisible by the number of devices. + 3. the sharding spec's entire shape must match the tensor shape + # + """ + # make sure all dims are covered in sharding spec + sharding_len = len(sharding_spec.sharding_sequence) + tensor_num_dim = tensor.dim() + num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0] + num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1] + assert sharding_len == tensor_num_dim, \ + f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).' + + # make sure the sharding is valid for each dim + for i in range(tensor_num_dim): + dim_size = tensor.shape[i] + dim_spec = sharding_spec.sharding_sequence[i] + + if str(dim_spec).startswith('S'): + devices_str = str(dim_spec).lstrip('S') + num_devices = 1 + + if '0' in devices_str: + num_devices *= num_devices_in_col + if '1' in devices_str: + num_devices *= num_devices_in_row + + assert dim_size >= num_devices and dim_size % num_devices == 0, \ + f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.' + + # make sure the entire shape matches the physical tensor shape + assert sharding_spec.entire_shape == tensor.shape, \ + f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}' + + +def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any: + """process object recursively, like pytree + + Args: + obj (:class:`Any`): object to process + fn (:class:`Callable`): a function to process subobject in obj + process_types (:class: `type | tuple[type]`): types to determine the type to process + map_all (:class: `bool`): if map_all is True, then any type of element will use fn + + Returns: + :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn` + """ + if isinstance(obj, dict): + return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj} + elif isinstance(obj, tuple): + return tuple(pytree_map(o, fn, process_types, map_all) for o in obj) + elif isinstance(obj, list): + return list(pytree_map(o, fn, process_types, map_all) for o in obj) + elif isinstance(obj, process_types): + return fn(obj) + else: + return fn(obj) if map_all else obj diff --git a/colossalai/auto_parallel/tensor_shard/utils/reshape.py b/colossalai/auto_parallel/tensor_shard/utils/reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..a32a14bf7d577713ae2cb986ffbb42d87b0cabc1 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/reshape.py @@ -0,0 +1,192 @@ +from enum import Enum +from typing import Dict, List, Tuple + +import torch + + +class PreviousStatus(Enum): + """ + This class shows the status of previous comparision. + """ + RESET = 0 + # ORIGIN means the dimension size of original tensor is larger in the previous comparision. + ORIGIN = 1 + # TGT means the dimension size of target tensor is larger in the previous comparision. + TGT = 2 + + +def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> Dict[Tuple[int], Tuple[int]]: + """ + This method is used to detect the reshape mapping between original tensor and target tensor. + + Returns: + reshape_mapping_dict: The dictionary shows how a tuple of origin dims(keys) mapping to the related + target dims(values) during reshaping operation. + Examples: + import torch + origin_shape = torch.Size([4, 4, 4]) + tgt_shape = torch.Size([2, 8, 2, 2]) + reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) + print(reshape_mapping_dict) + Output: + {(2,): (3, 2), (1, 0): (1,), (0,): (0, 1)} + """ + + # reverse the shape object + origin_shape = list(origin_shape) + tgt_shape = list(tgt_shape) + origin_shape.reverse() + tgt_shape.reverse() + + # initialize arguments + reshape_mapping_dict = {} + origin_len = len(origin_shape) + tgt_len = len(tgt_shape) + origin_index = 0 + tgt_index = 0 + original_dimension_size = origin_shape[origin_index] + tgt_dimension_size = tgt_shape[tgt_index] + tgt_dims = [tgt_len - tgt_index - 1] + origin_dims = [origin_len - origin_index - 1] + previous_label = PreviousStatus.RESET + + while origin_index != len(origin_shape) or tgt_index != len(tgt_shape): + if original_dimension_size == tgt_dimension_size: + reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) + # if the origin_dims has no element, it means the original tensor has been fully matched. + # Therefore, we do not have to increase the origin_index for that case. + if len(origin_dims) > 0: + origin_index += 1 + # if the tgt_dims has no element, it means the original tensor has been fully matched. + # Therefore, we do not have to increase the tgt_index for that case. + if len(tgt_dims) > 0: + tgt_index += 1 + # the last step of loop should always end with condition + # so we need to manually skip the preparation for next step + # in the last step. + if origin_index == len(origin_shape) and tgt_index == len(tgt_shape): + continue + + # If origin_index equals to origin_len, we just need to set the original_dimension_size + # to 1 to match the remaining '1's in the target tensor shape. + if origin_index == len(origin_shape): + original_dimension_size = 1 + origin_dims = [] + else: + original_dimension_size = origin_shape[origin_index] + origin_dims = [origin_len - origin_index - 1] + + # If tgt_index equals to tgt_len, we just need to set the tgt_dimension_size + # to 1 to match the remaining '1's in the original tensor shape. + if tgt_index == len(tgt_shape): + tgt_dimension_size = 1 + tgt_dims = [] + else: + tgt_dimension_size = tgt_shape[tgt_index] + tgt_dims = [tgt_len - tgt_index - 1] + + previous_label = PreviousStatus.RESET + + elif original_dimension_size > tgt_dimension_size: + tgt_index += 1 + + if previous_label == PreviousStatus.TGT: + # if the target dimension size is larger in the previous comparision, which means + # the origin dimension size has already accumulated larger than target dimension size, so + # we need to offload the origin dims and tgt dims into the reshape_mapping_dict. + reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) + original_dimension_size = original_dimension_size // tgt_dimension_size + origin_dims = [origin_len - origin_index - 1] + tgt_dimension_size = tgt_shape[tgt_index] + tgt_dims = [tgt_len - tgt_index - 1, tgt_len - tgt_index] + # reset the previous_label after offloading the origin dims and tgt dims + previous_label = PreviousStatus.RESET + else: + # accumulate the tgt_dimension_size until tgt_dimension_size larger than original_dimension_size + tgt_dimension_size *= tgt_shape[tgt_index] + tgt_dims.append(tgt_len - tgt_index - 1) + previous_label = PreviousStatus.ORIGIN + + else: + origin_index += 1 + + if previous_label == PreviousStatus.ORIGIN: + # if the origin element is larger in the previous comparision, which means + # the target element has already accumulated larger than origin element, so + # we need to offload the origin dims and tgt dims into the reshape_mapping_dict. + reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) + tgt_dimension_size = tgt_dimension_size // original_dimension_size + tgt_dims = [tgt_len - tgt_index - 1] + original_dimension_size = origin_shape[origin_index] + origin_dims = [origin_len - origin_index - 1, origin_len - origin_index] + # reset the previous_label after offloading the origin dims and tgt dims + previous_label = PreviousStatus.RESET + else: + # accumulate the original_dimension_size until original_dimension_size larger than tgt_dimension_size + original_dimension_size *= origin_shape[origin_index] + origin_dims.append(origin_len - origin_index - 1) + previous_label = PreviousStatus.TGT + + return reshape_mapping_dict + + +def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]], + reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool: + """ + This method is used to check whether the reshape operation could implement without converting + the input to fully replicated status. + + Rule: + For a sharded dimension of input tensor, if it is not the minimum element of the input tuple, + the function will return false. + To illustrate this issue, there are two cases to analyse: + 1. no sharded dims in the input tuple: we could do the reshape operation safely just as the normal + operation without distributed tensor. + 2. sharded dims in the input tuple: the sharded dim must be the minimum element, then during shape + consistency process, torch.cat will be implemented on the sharded dim, and everything after the sharded + dim get recovered. + + Examples: + # the second dimension of the input has been sharded. + input_dim_partition_dict = {1: [1]} + origin_shape = torch.Size([8, 4, 2]) + tgt_shape = torch.Size([2, 4, 8]) + reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) + # {(2, 1): (2,), (0,): (1, 0)} + # the sharded dim of input is 1, which is the minimum element of the tuple (2, 1), + # so we do not have to convert the input to fully replicated status. + print(check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict)) + + Output: + True + """ + sharded_dims = list(input_dim_partition_dict.keys()) + for input_dims in reshape_mapping_dict.keys(): + # if input_dims has no element, we could just skip this iteration. + if len(input_dims) == 0: + continue + min_element = min(input_dims) + for dim in input_dims: + if dim in sharded_dims and dim is not min_element: + return False + return True + + +def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]], + reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[Tuple[int], Tuple[int]]: + """ + This method is used to infer the output dim partition dict for a reshape operation, + given the input dim partition dict and reshape mapping dict. + """ + assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \ + 'we only infer output dim partition dict for the reshape operation could keep sharding spec.' + sharded_dims = list(input_dim_partition_dict.keys()) + output_dim_partition_dict = {} + for input_dims, output_dims in reshape_mapping_dict.items(): + for dim in input_dims: + if dim in sharded_dims: + output_dim_partition_dict[min(output_dims)] = input_dim_partition_dict[dim] + # we could break because input dims cannot contain two sharded dims, otherwise + # the keep sharding status check will fail. + break + return output_dim_partition_dict diff --git a/colossalai/auto_parallel/tensor_shard/utils/sharding.py b/colossalai/auto_parallel/tensor_shard/utils/sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ce59e0b5772679be11e960322e3110c500d6aa --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/sharding.py @@ -0,0 +1,120 @@ +import operator +from copy import deepcopy +from functools import reduce +from typing import Dict + +import torch + +from colossalai.tensor.sharding_spec import ShardingSpec + +__all__ = [ + 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', + 'enumerate_all_possible_2d_sharding', 'generate_sharding_size' +] + + +def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec: + """ + Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place. + + Args: + sharding_spec (ShardingSpec): the sharding spec for which partition dim are switched + dim1 (int): the tensor dimension to switch + dim2 (int): the tensor dimension to switch + """ + assert len(sharding_spec.entire_shape) >= 2, \ + 'The entire_shape of the sharding spec must have at least 2 dimensions' + dim_partition_dict = sharding_spec.dim_partition_dict + + # transpose the dim partition + dim1_partition = dim_partition_dict.pop(dim1, None) + dim2_partition = dim_partition_dict.pop(dim2, None) + + if dim1_partition: + dim_partition_dict[dim2] = dim1_partition + if dim2_partition: + dim_partition_dict[dim1] = dim2_partition + + # get the transposed shape + new_shape = list(sharding_spec.entire_shape[:]) + new_shape[dim2], new_shape[dim1] = new_shape[dim1], new_shape[dim2] + new_shape = torch.Size(new_shape) + + # re-init the sharding spec + sharding_spec.__init__(sharding_spec.device_mesh, new_shape, dim_partition_dict) + return sharding_spec + + +def update_partition_dim(sharding_spec: ShardingSpec, + dim_mapping: Dict[int, int], + physical_shape: torch.Size, + inplace: bool = False): + """ + This method is used to update the partition dim dict from the logical one to the physical one. + + Args: + sharding_spec (ShardingSpec): the sharding spec for which partition dims are updated + dim_mapping (Dict[int, int]): the mapping from the logical tensor dimension to the physical tensor dimension + physical_shape (torch.Size): the physical shape for the tensor + """ + + if inplace: + current_sharding_spec = sharding_spec + else: + current_sharding_spec = deepcopy(sharding_spec) + + old_dim_partition_dict = current_sharding_spec.dim_partition_dict + new_dim_partition_dict = {} + + # assign new dim + for old_dim, new_dim in dim_mapping.items(): + mesh_dims = old_dim_partition_dict.pop(old_dim) + new_dim_partition_dict[new_dim] = mesh_dims + + for tensor_dim, mesh_dims in old_dim_partition_dict.items(): + if tensor_dim in new_dim_partition_dict: + raise KeyError(f"There are duplicated entries for the tensor sharding dimension {tensor_dim}") + else: + new_dim_partition_dict[tensor_dim] = mesh_dims + + # update sharding spec + current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh, + entire_shape=physical_shape, + dim_partition_dict=new_dim_partition_dict) + return current_sharding_spec + + +def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size): + dim_partition_list = [] + # enumerate all the 2D sharding cases + for i in range(dim_size): + for j in range(i + 1, dim_size): + dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]} + dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]} + dim_partition_list.append(dim_partition_dict_0) + dim_partition_list.append(dim_partition_dict_1) + for i in range(dim_size): + dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]} + dim_partition_list.append(dim_partition_dict_flatten) + + return dim_partition_list + + +def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size): + dim_partition_list = [] + # enumerate all the 1D sharding cases + for i in range(dim_size): + dim_partition_dict_0 = {i: [mesh_dim_0]} + dim_partition_list.append(dim_partition_dict_0) + + return dim_partition_list + + +def generate_sharding_size(dim_partition_dict, device_mesh): + total_sharding_size = 1 + for mesh_dim_list in dim_partition_dict.values(): + mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list] + sharding_size = reduce(operator.mul, mesh_dim_sharding_size) + total_sharding_size *= sharding_size + + return total_sharding_size diff --git a/colossalai/builder/__init__.py b/colossalai/builder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf09e1e7a31a15e979c5358c5da27683a0ccb2f9 --- /dev/null +++ b/colossalai/builder/__init__.py @@ -0,0 +1,3 @@ +from .builder import build_from_config, build_from_registry, build_gradient_handler + +__all__ = ['build_gradient_handler', 'build_from_config', 'build_from_registry'] diff --git a/colossalai/builder/builder.py b/colossalai/builder/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..4a907601327c9c938243bfee121165937c02537c --- /dev/null +++ b/colossalai/builder/builder.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import inspect + +from colossalai.registry import * + + +def build_from_config(module, config: dict): + """Returns an object of :class:`module` constructed from `config`. + + Args: + module: A python or user-defined class + config: A python dict containing information used in the construction of the return object + + Returns: An ``object`` of interest + + Raises: + AssertionError: Raises an AssertionError if `module` is not a class + + """ + assert inspect.isclass(module), 'module must be a class' + return module(**config) + + +def build_from_registry(config, registry: Registry): + r"""Returns an object constructed from `config`, the type of the object + is specified by `registry`. + + Note: + the `config` is used to construct the return object such as `LAYERS`, `OPTIMIZERS` + and other support types in `registry`. The `config` should contain + all required parameters of corresponding object. The details of support + types in `registry` and the `mod_type` in `config` could be found in + `registry `_. + + Args: + config (dict or :class:`colossalai.context.colossalai.context.Config`): information + used in the construction of the return object. + registry (:class:`Registry`): A registry specifying the type of the return object + + Returns: + A Python object specified by `registry`. + + Raises: + Exception: Raises an Exception if an error occurred when building from registry. + """ + config_ = config.copy() # keep the original config untouched + assert isinstance(registry, Registry), f'Expected type Registry but got {type(registry)}' + + mod_type = config_.pop('type') + assert registry.has(mod_type), f'{mod_type} is not found in registry {registry.name}' + try: + obj = registry.get_module(mod_type)(**config_) + except Exception as e: + print(f'An error occurred when building {mod_type} from registry {registry.name}', flush=True) + raise e + + return obj + + +def build_gradient_handler(config, model, optimizer): + """Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`, + `model` and `optimizer`. + + Args: + config (dict or :class:`colossalai.context.Config`): A python dict or + a :class:`colossalai.context.Config` object containing information + used in the construction of the ``GRADIENT_HANDLER``. + model (:class:`nn.Module`): A model containing parameters for the gradient handler + optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing parameters for the gradient handler + + Returns: + An object of :class:`colossalai.engine.BaseGradientHandler` + """ + config_ = config.copy() + config_['model'] = model + config_['optimizer'] = optimizer + return build_from_registry(config_, GRADIENT_HANDLER) diff --git a/colossalai/cli/__init__.py b/colossalai/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..658e35e4c72e77f7b3161bf86b0c9600a80562e5 --- /dev/null +++ b/colossalai/cli/__init__.py @@ -0,0 +1,3 @@ +from .cli import cli + +__all__ = ['cli'] diff --git a/colossalai/cli/benchmark/__init__.py b/colossalai/cli/benchmark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c020d33b63bc38ee7631af0f46fcd5f30aad3c55 --- /dev/null +++ b/colossalai/cli/benchmark/__init__.py @@ -0,0 +1,27 @@ +import click + +from .utils import * +from .benchmark import run_benchmark +from colossalai.context import Config + +__all__ = ['benchmark'] + + +@click.command() +@click.option("-g", "--gpus", type=int, default=None, help="Total number of devices to use.") +@click.option("-b", "--batch_size", type=int, default=8, help="Batch size of the input tensor.") +@click.option("-s", "--seq_len", type=int, default=512, help="Sequence length of the input tensor.") +@click.option("-d", "--dimension", type=int, default=1024, help="Hidden dimension of the input tensor.") +@click.option("-w", "--warmup_steps", type=int, default=10, help="The number of warmup steps.") +@click.option("-p", "--profile_steps", type=int, default=50, help="The number of profiling steps.") +@click.option("-l", "--layers", type=int, default=2) +@click.option("-m", + "--model", + type=click.Choice(['mlp'], case_sensitive=False), + default='mlp', + help="Select the model to benchmark, currently only supports MLP") +def benchmark(gpus: int, batch_size: int, seq_len: int, dimension: int, warmup_steps: int, profile_steps: int, + layers: int, model: str): + args_dict = locals() + args = Config(args_dict) + run_benchmark(args) diff --git a/colossalai/cli/benchmark/benchmark.py b/colossalai/cli/benchmark/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..43632b150327972190d0e23e6c671432aa03cd12 --- /dev/null +++ b/colossalai/cli/benchmark/benchmark.py @@ -0,0 +1,103 @@ +import colossalai +import click +import torch.multiprocessing as mp + +from functools import partial +from typing import List, Dict + +from colossalai.context import Config +from colossalai.context.random import reset_seeds +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.utils import free_port, MultiTimer +from colossalai.cli.benchmark.utils import find_all_configs, profile_model, get_batch_data +from .models import MLP + + +def run_benchmark(args: Config) -> None: + """ + Run benchmarking with torch.multiprocessing. + """ + + # sanity checks + if args.gpus is None: + click.echo("Error: --num_gpus is not given") + exit() + if args.gpus <= 1: + click.echo("Warning: tensor parallel will be activated with at least 2 devices.") + + click.echo("=== Benchmarking Parameters ===") + for k, v in args.items(): + click.echo(f'{k}: {v}') + click.echo('') + + config_list = find_all_configs(args.gpus) + + avail_ports = [free_port() for _ in range(len(config_list))] + run_func = partial(run_dist_profiling, + world_size=args.gpus, + port_list=avail_ports, + config_list=config_list, + hyperparams=args) + mp.spawn(run_func, nprocs=args.gpus) + + +def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_list: List[Dict], + hyperparams: Config) -> None: + """ + A function executed for profiling, this function should be spawn by torch.multiprocessing. + + Args: + rank (int): rank of the process + world_size (int): the number of processes + port_list (List[int]): a list of free ports for initializing distributed networks + config_list (List[Dict]): a list of configuration + hyperparams (Config): the hyperparameters given by the user + + """ + + # disable logging for clean output + disable_existing_loggers() + logger = get_dist_logger() + logger.set_level('WARNING') + + for config, port in zip(config_list, port_list): + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + timer = MultiTimer() + + # 1D parallel should be skipped if in_features or out_features is not able to be divided exactly by 1D parallel size. + if config.parallel.tensor.mode == '1d' and hyperparams.dimension % config.parallel.tensor.size != 0: + click.echo( + "1D parallel will be skipped because in_features or out_features is not able to be divided exactly by 1D parallel size." + ) + continue + + if hyperparams.model == 'mlp': + model = MLP(dim=hyperparams.dimension, layers=hyperparams.layers) + else: + if gpc.get_global_rank() == 0: + click.echo("Error: Invalid argument for --model") + exit() + + data_func = partial(get_batch_data, + dim=hyperparams.dimension, + batch_size=hyperparams.batch_size, + seq_length=hyperparams.seq_len, + mode=config.parallel.tensor.mode) + + fwd_time, bwd_time, max_allocated, max_cached = profile_model(model=model, + warmup_steps=hyperparams.warmup_steps, + profile_steps=hyperparams.profile_steps, + data_func=data_func, + timer=timer) + + gpc.destroy() + reset_seeds() + + if gpc.get_global_rank() == 0: + config_str = ', '.join([f'{k}: {v}' for k, v in config.parallel.tensor.items()]) + click.echo(f"=== {config_str} ===") + click.echo(f"Average forward time: {fwd_time}") + click.echo(f"Average backward time: {bwd_time}") + click.echo(f"Max allocated GPU memory: {max_allocated}") + click.echo(f"Max cached GPU memory: {max_cached}\n") diff --git a/colossalai/cli/benchmark/models.py b/colossalai/cli/benchmark/models.py new file mode 100644 index 0000000000000000000000000000000000000000..38ea54188b8c31229f92c194aa157f381d9a1e36 --- /dev/null +++ b/colossalai/cli/benchmark/models.py @@ -0,0 +1,17 @@ +import torch +import colossalai.nn as col_nn + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int, layers: int): + super().__init__() + self.layers = torch.nn.ModuleList() + + for _ in range(layers): + self.layers.append(col_nn.Linear(dim, dim)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x diff --git a/colossalai/cli/benchmark/utils.py b/colossalai/cli/benchmark/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..825b795f21f680bcb2e2ea5eee4b328c2e1777db --- /dev/null +++ b/colossalai/cli/benchmark/utils.py @@ -0,0 +1,158 @@ +import math +import time +import torch + +from colossalai.utils import MultiTimer +from colossalai.context import ParallelMode, Config +from typing import List, Dict, Tuple, Callable + + +def get_time_stamp() -> int: + """ + Return the time stamp for profiling. + + Returns: + time_stamp (int): the time given by time.time() + """ + + torch.cuda.synchronize() + time_stamp = time.time() + return time_stamp + + +def get_memory_states() -> Tuple[float]: + """ + Return the memory statistics. + + Returns: + max_allocated (float): the allocated CUDA memory + max_cached (float): the cached CUDA memory + """ + + max_allocated = torch.cuda.max_memory_allocated() / (1024**3) + max_cached = torch.cuda.max_memory_reserved() / (1024**3) + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + return max_allocated, max_cached + + +def find_all_configs(device_cnt: int) -> List[Dict]: + """ + Find all possible configurations for tensor parallelism + + Args: + device_cnt (int): the number of devices + + Returns: + config_list (List[Dict]): a list of configurations + """ + + def _is_square(num): + # 2D parallel should be implemented with at least 2 devices. + if num <= 1: + return False + return math.floor(math.sqrt(num))**2 == num + + def _is_cube(num): + # 3D parallel should be implemented with at least 2 devices. + if num <= 1: + return False + return math.floor(num**(1. / 3.))**3 == num + + config_list = [] + + # add non-parallel config + config = dict(parallel=dict(tensor=dict(size=device_cnt, mode=None))) + config_list.append(config) + + # add 1D config + config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='1d'))) + config_list.append(config) + + # add 2D config only if device_cnt is a square + if _is_square(device_cnt): + config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2d'))) + config_list.append(config) + + # check for 2.5D + # iterate over depth + for depth in range(1, device_cnt): + if device_cnt % depth == 0 and _is_square(device_cnt // depth): + config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2.5d', depth=depth))) + config_list.append(config) + + # check for 3D if device_cnt is a cube + if _is_cube(device_cnt): + config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='3d'))) + config_list.append(config) + + config_list = [Config(cfg) for cfg in config_list] + return config_list + + +def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable, + timer: MultiTimer) -> Tuple[float]: + """ + Profile the forward and backward of a model + + Args: + model (torch.nn.Module): a PyTorch model + warmup_steps (int): the number of steps for warmup + profile_steps (int): the number of steps for profiling + data_func (Callable): a function to generate random data + timer (colossalai.utils.Multitimer): a timer instance for time recording + + Returns: + fwd_time (float): the average forward time taken by forward pass in second + bwd_time (float): the average backward time taken by forward pass in second + max_allocated (float): the maximum GPU memory allocated in GB + max_cached (float): the maximum GPU memory cached in GB + """ + + def _run_step(data): + timer.start('forward') + out = model(data) + timer.stop('forward', keep_in_history=True) + timer.start('backward') + out.mean().backward() + timer.stop('backward', keep_in_history=True) + + data_list = [data_func() for _ in range(warmup_steps)] + for data in data_list: + _run_step(data) + timer.reset('forward') + timer.reset('backward') + + for _ in range(profile_steps): + data = data_func() + _run_step(data) + + max_allocated, max_cached = get_memory_states() + fwd_time = timer.get_timer('forward').get_history_mean() + bwd_time = timer.get_timer('backward').get_history_mean() + return fwd_time, bwd_time, max_allocated, max_cached + + +def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: ParallelMode) -> torch.Tensor: + """ + Return a random data of shape (batch_size, seq_length, dim) for profiling. + + Args: + dim (int): hidden size + batch_size (int): the number of data samples + seq_length (int): the number of tokens + mode (ParallelMode): Colossal-AI ParallelMode enum + + Returns: + data (torch.Tensor): random data + """ + + if mode in ['2d', '2.5d']: + batch_size = batch_size // 2 + dim = dim // 2 + elif mode == '3d': + batch_size = batch_size // 4 + dim = dim // 2 + + data = torch.rand(batch_size, seq_length, dim).cuda() + return data diff --git a/colossalai/cli/check/__init__.py b/colossalai/cli/check/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a86b32bb6a181a7c75bfa6682c2f7514169d9a7f --- /dev/null +++ b/colossalai/cli/check/__init__.py @@ -0,0 +1,13 @@ +import click +from .check_installation import check_installation + +__all__ = ['check'] + + +@click.command(help="Check if Colossal-AI is correct based on the given option") +@click.option('-i', '--installation', is_flag=True, help="Check if Colossal-AI is built correctly") +def check(installation): + if installation: + check_installation() + return + click.echo("No option is given") diff --git a/colossalai/cli/check/check_installation.py b/colossalai/cli/check/check_installation.py new file mode 100644 index 0000000000000000000000000000000000000000..a12b244027940d2c7306f00bcecf6dacd1494415 --- /dev/null +++ b/colossalai/cli/check/check_installation.py @@ -0,0 +1,95 @@ +import subprocess + +import click +import torch +from torch.utils.cpp_extension import CUDA_HOME + +import colossalai + + +def check_installation(): + cuda_ext_installed = _check_cuda_extension_installed() + cuda_version, torch_version, torch_cuda_version = _check_cuda_torch() + colossalai_verison, torch_version_required, cuda_version_required = _parse_colossalai_version() + + cuda_compatibility = _get_compatibility_string([cuda_version, torch_cuda_version, cuda_version_required]) + torch_compatibility = _get_compatibility_string([torch_version, torch_version_required]) + + click.echo(f'#### Installation Report ####\n') + click.echo(f"Colossal-AI version: {colossalai_verison}") + click.echo(f'----------------------------') + click.echo(f"PyTorch Version: {torch_version}") + click.echo(f"PyTorch Version required by Colossal-AI: {torch_version_required}") + click.echo(f'PyTorch version match: {torch_compatibility}') + click.echo(f'----------------------------') + click.echo(f"System CUDA Version: {cuda_version}") + click.echo(f"CUDA Version required by PyTorch: {torch_cuda_version}") + click.echo(f"CUDA Version required by Colossal-AI: {cuda_version_required}") + click.echo(f"CUDA Version Match: {cuda_compatibility}") + click.echo(f'----------------------------') + click.echo(f"CUDA Extension: {cuda_ext_installed}") + + +def _get_compatibility_string(versions): + + # split version into [major, minor, patch] + versions = [version.split('.') for version in versions] + + for version in versions: + if len(version) == 2: + # x means unknown + version.append('x') + + for idx, version_values in enumerate(zip(*versions)): + equal = len(set(version_values)) == 1 + + if idx in [0, 1] and not equal: + # if the major/minor versions do not match + # return a cross + return 'x' + elif idx == 1: + # if the minor versions match + # return a tick + return u'\u2713' + else: + continue + + +def _parse_colossalai_version(): + colossalai_verison = colossalai.__version__.split('+')[0] + torch_version_required = colossalai.__version__.split('torch')[1].split('cu')[0] + cuda_version_required = colossalai.__version__.split('cu')[1] + return colossalai_verison, torch_version_required, cuda_version_required + + +def _check_cuda_extension_installed(): + try: + import colossalai._C.fused_optim + is_cuda_extension_installed = u'\u2713' + except ImportError: + is_cuda_extension_installed = 'x' + return is_cuda_extension_installed + + +def _check_cuda_torch(): + # get cuda version + if CUDA_HOME is None: + cuda_version = 'N/A (CUDA_HOME is not set)' + else: + raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + cuda_version = f'{bare_metal_major}.{bare_metal_minor}' + + # get torch version + torch_version = torch.__version__.split('+')[0] + + # get cuda version in pytorch build + torch_cuda_major = torch.version.cuda.split(".")[0] + torch_cuda_minor = torch.version.cuda.split(".")[1] + torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}' + + return cuda_version, torch_version, torch_cuda_version diff --git a/colossalai/cli/cli.py b/colossalai/cli/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..3e5b9ae6343f1d7585055a0e8894cf2128305b59 --- /dev/null +++ b/colossalai/cli/cli.py @@ -0,0 +1,24 @@ +import click +from .launcher import run +from .check import check +from .benchmark import benchmark + + +class Arguments(): + + def __init__(self, arg_dict): + for k, v in arg_dict.items(): + self.__dict__[k] = v + + +@click.group() +def cli(): + pass + + +cli.add_command(run) +cli.add_command(check) +cli.add_command(benchmark) + +if __name__ == '__main__': + cli() diff --git a/colossalai/cli/launcher/__init__.py b/colossalai/cli/launcher/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ada68b4b68fd5743330f04fce3c7afe10e65bdc --- /dev/null +++ b/colossalai/cli/launcher/__init__.py @@ -0,0 +1,85 @@ +import click +from .run import launch_multi_processes +from colossalai.context import Config + + +@click.command(help="Launch distributed training on a single node or multiple nodes", + context_settings=dict(ignore_unknown_options=True)) +@click.option("-H", + "-host", + "--host", + type=str, + default=None, + help="the list of hostnames to launch in the format ,") +@click.option( + "--hostfile", + type=str, + default=None, + help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname") +@click.option("--include", + type=str, + default=None, + help="Specify computing devices to use during execution. String format is ,," + " only effective when used with --hostfile.") +@click.option( + "--exclude", + type=str, + default=None, + help= + "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --includ," + " only effective when used with --hostfile.") +@click.option("--num_nodes", + type=int, + default=-1, + help="Total number of worker nodes to use, only effective when used with --hostfile.") +@click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.") +@click.option("--master_port", + type=int, + default=29500, + help="(optional) Port used by PyTorch distributed for communication during distributed training.") +@click.option("--master_addr", + type=str, + default="127.0.0.1", + help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.") +@click.option( + "--extra_launch_args", + type=str, + default=None, + help= + "Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. " + "This will be converted to --arg1=1 --arg2=2 during execution") +@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection") +@click.argument("user_script", type=str) +@click.argument('user_args', nargs=-1) +def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str, + master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None: + """ + To launch multiple processes on a single node or multiple nodes via command line. + + Usage:: + # run with 4 GPUs on the current node use default port 29500 + colossalai run --nprocs_per_node 4 train.py + + # run with 2 GPUs on the current node at port 29550 + colossalai run --nprocs_per_node 4 --master_port 29550 train.py + + # run on two nodes + colossalai run --host , --master_addr host1 --nprocs_per_node 4 train.py + + # run with hostfile + colossalai run --hostfile --master_addr --nprocs_per_node 4 train.py + + # run with hostfile with only included hosts + colossalai run --hostfile --master_addr host1 --include host1,host2 --nprocs_per_node 4 train.py + + # run with hostfile excluding the hosts selected + colossalai run --hostfile --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py + """ + if not user_script.endswith('.py'): + click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help') + exit() + + args_dict = locals() + args = Config(args_dict) + args.user_args = list(args.user_args) + launch_multi_processes(args) diff --git a/colossalai/cli/launcher/hostinfo.py b/colossalai/cli/launcher/hostinfo.py new file mode 100644 index 0000000000000000000000000000000000000000..2f0830c5880d8244a495dd4489db8a7ec3f1156e --- /dev/null +++ b/colossalai/cli/launcher/hostinfo.py @@ -0,0 +1,122 @@ +from typing import List +import socket + + +class HostInfo: + """ + A data class to store host connection-related data. + + Args: + hostname (str): name or IP address of the host + port (str): the port for ssh connection + """ + + def __init__( + self, + hostname: str, + port: str = None, + ): + self.hostname = hostname + self.port = port + self.is_local_host = HostInfo.is_host_localhost(hostname, port) + + @staticmethod + def is_host_localhost(hostname: str, port: str = None) -> None: + """ + Check if the host refers to the local machine. + + Args: + hostname (str): name or IP address of the host + port (str): the port for ssh connection + + Returns: + bool: True if it is local, False otherwise + """ + + if port is None: + port = 22 # no port specified, lets just use the ssh port + hostname = socket.getfqdn(hostname) + if hostname in ("localhost", "127.0.0.1", "0.0.0.0"): + return True + localhost = socket.gethostname() + localaddrs = socket.getaddrinfo(localhost, port) + targetaddrs = socket.getaddrinfo(hostname, port) + for (family, socktype, proto, canonname, sockaddr) in localaddrs: + for (rfamily, rsocktype, rproto, rcanonname, rsockaddr) in targetaddrs: + if rsockaddr[0] == sockaddr[0]: + return True + return False + + def __str__(self): + return f'hostname: {self.hostname}, port: {self.port}' + + def __repr__(self): + return self.__str__() + + +class HostInfoList: + """ + A data class to store a list of HostInfo objects. + """ + + def __init__(self): + self.hostinfo_list = [] + + def append(self, hostinfo: HostInfo) -> None: + """ + Add an HostInfo object to the list. + + Args: + hostinfo (HostInfo): host information + """ + + self.hostinfo_list.append(hostinfo) + + def remove(self, hostname: str) -> None: + """ + Add an HostInfo object to the list. + + Args: + hostname (str): the name of the host + """ + + hostinfo = self.get_hostinfo(hostname) + self.hostinfo_list.remove(hostinfo) + + def get_hostinfo(self, hostname: str) -> HostInfo: + """ + Return the HostInfo object which matches with the hostname. + + Args: + hostname (str): the name of the host + + Returns: + hostinfo (HostInfo): the HostInfo object which matches with the hostname + """ + + for hostinfo in self.hostinfo_list: + if hostinfo.hostname == hostname: + return hostinfo + + raise Exception(f"Hostname {hostname} is not found") + + def has(self, hostname: str) -> bool: + """ + Check if the hostname has been added. + + Args: + hostname (str): the name of the host + + Returns: + bool: True if added, False otherwise + """ + for hostinfo in self.hostinfo_list: + if hostinfo.hostname == hostname: + return True + return False + + def __iter__(self): + return iter(self.hostinfo_list) + + def __len__(self): + return len(self.hostinfo_list) diff --git a/colossalai/cli/launcher/multinode_runner.py b/colossalai/cli/launcher/multinode_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..c45ad5e5a082b6d30c6ccbc16bd7803cdd54c993 --- /dev/null +++ b/colossalai/cli/launcher/multinode_runner.py @@ -0,0 +1,119 @@ +import fabric +from .hostinfo import HostInfo, HostInfoList +from multiprocessing import Pipe, Process +from multiprocessing import connection as mp_connection +import click + + +def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection, + send_conn: mp_connection.Connection, env: dict) -> None: + """ + Use fabric connection to execute command on local or remote hosts. + + Args: + hostinfo (HostInfo): host information + workdir (str): the directory to execute the command + recv_conn (multiprocessing.connection.Connection): receive messages from the master sender + send_conn (multiprocessing.connection.Connection): send messages to the master receiver + env (dict): a dictionary for environment variables + """ + + fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port) + finish = False + env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()]) + + # keep listening until exit + while not finish: + # receive cmd + cmds = recv_conn.recv() + + if cmds == 'exit': + # exit from the loop + finish = True + break + else: + # execute the commands + try: + # cd to execute directory + with fab_conn.cd(workdir): + # propagate the runtime environment + with fab_conn.prefix(f"export {env_msg}"): + if hostinfo.is_local_host: + # execute on the local machine + fab_conn.local(cmds, hide=False) + else: + # execute on the remote machine + fab_conn.run(cmds, hide=False) + send_conn.send('success') + except: + click.echo(f"Error: failed to run {cmds} on {hostinfo.hostname}") + send_conn.send('failure') + + # shutdown + send_conn.send("finish") + fab_conn.close() + + +class MultiNodeRunner: + """ + A runner to execute commands on an array of machines. This runner + is inspired by Nezha (https://github.com/zhuzilin/NeZha). + """ + + def __init__(self): + self.processes = {} + self.master_send_conns = {} + self.master_recv_conns = {} + + def connect(self, host_info_list: HostInfoList, workdir: str, env: dict) -> None: + """ + Establish connections to a list of hosts + + Args: + host_info_list (HostInfoList): a list of HostInfo objects + workdir (str): the directory where command is executed + env (dict): environment variables to propagate to hosts + """ + for hostinfo in host_info_list: + master_send_conn, worker_recv_conn = Pipe() + master_recv_conn, worker_send_conn = Pipe() + p = Process(target=run_on_host, args=(hostinfo, workdir, worker_recv_conn, worker_send_conn, env)) + p.start() + self.processes[hostinfo.hostname] = p + self.master_recv_conns[hostinfo.hostname] = master_recv_conn + self.master_send_conns[hostinfo.hostname] = master_send_conn + + def send(self, hostinfo: HostInfo, cmd: str) -> None: + """ + Send a command to a local/remote host. + + Args: + hostinfo (HostInfo): host information + cmd (str): the command to execute + """ + + assert hostinfo.hostname in self.master_send_conns, \ + f'{hostinfo} is not found in the current connections' + conn = self.master_send_conns[hostinfo.hostname] + conn.send(cmd) + + def stop_all(self) -> None: + """ + Stop connections to all hosts. + """ + + for hostname, conn in self.master_send_conns.items(): + conn.send('exit') + + def recv_from_all(self) -> dict: + """ + Receive messages from all hosts + + Returns: + msg_from_node (dict): a dictionry which contains messages from each node + """ + + msg_from_node = dict() + for hostname, conn in self.master_recv_conns.items(): + msg_from_node[hostname] = conn.recv() + return msg_from_node diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py new file mode 100644 index 0000000000000000000000000000000000000000..e078a57c15c916f9d291641564e9f9e271753938 --- /dev/null +++ b/colossalai/cli/launcher/run.py @@ -0,0 +1,281 @@ +import click +import sys +import os +import torch +from colossalai.context import Config +from .multinode_runner import MultiNodeRunner +from .hostinfo import HostInfo, HostInfoList +from typing import List +from packaging import version + +# Constants that define our syntax +NODE_SEP = ',' + + +def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList: + """ + Parse the hostfile to obtain a list of hosts. + + A hostfile should look like: + worker-0 + worker-1 + worker-2 + ... + + Args: + hostfile_path (str): the path to the hostfile + ssh_port (int): the port to connect to the host + """ + + if not os.path.isfile(hostfile_path): + click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}") + exit() + + with open(hostfile_path, 'r') as fd: + device_pool = HostInfoList() + + for line in fd.readlines(): + line = line.strip() + if line == '': + # skip empty lines + continue + + # build the HostInfo object + hostname = line.strip() + hostinfo = HostInfo(hostname=hostname, port=ssh_port) + + if device_pool.has(hostname): + click.echo(f"Error: found duplicate host {hostname} in the hostfile") + exit() + + device_pool.append(hostinfo) + return device_pool + + +def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList: + '''Parse an inclusion or exclusion string and filter a hostfile dictionary. + + Examples: + include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1. + exclude_str="worker-1" will use all available devices except worker-1. + + Args: + device_pool (HostInfoList): a list of HostInfo objects + include_str (str): --include option passed by user, default None + exclude_str (str): --exclude option passed by user, default None + + Returns: + filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion + ''' + + # Ensure include/exclude are mutually exclusive + if include_str and exclude_str: + click.echo("--include and --exclude are mutually exclusive, only one can be used") + exit() + + # no-op + if include_str is None and exclude_str is None: + return device_pool + + # Either build from scratch or remove items + if include_str: + parse_str = include_str + filtered_hosts = HostInfoList() + elif exclude_str: + parse_str = exclude_str + filtered_hosts = device_pool + + # foreach node in the list + for node_config in parse_str.split(NODE_SEP): + hostname = node_config + hostinfo = device_pool.get_hostinfo(hostname) + # sanity check hostname + if not device_pool.has(hostname): + click.echo(f"Error: Hostname '{hostname}' not found in hostfile") + exit() + + if include_str: + filtered_hosts.append(hostinfo) + elif exclude_str: + filtered_hosts.remove(hostname) + + return filtered_hosts + + +def get_launch_command( + master_addr: str, + master_port: int, + nproc_per_node: int, + user_script: str, + user_args: List[str], + node_rank: int, + num_nodes: int, + extra_launch_args: str = None, +) -> str: + """ + Generate a command for distributed training. + + Args: + master_addr (str): the host of the master node + master_port (str): the port of the master node + nproc_per_node (str): the number of processes to launch on each node + user_script (str): the user Python file + user_args (str): the arguments for the user script + node_rank (int): the unique ID for the node + num_nodes (int): the number of nodes to execute jobs + + Returns: + cmd (str): the command the start distributed training + """ + + def _arg_dict_to_list(arg_dict): + ret = [] + + for k, v in arg_dict.items(): + if v: + ret.append(f'--{k}={v}') + else: + ret.append(f'--{k}') + return ret + + if extra_launch_args: + extra_launch_args_dict = dict() + for arg in extra_launch_args.split(','): + if '=' in arg: + k, v = arg.split('=') + extra_launch_args_dict[k] = v + else: + extra_launch_args_dict[arg] = None + extra_launch_args = extra_launch_args_dict + else: + extra_launch_args = dict() + + torch_version = version.parse(torch.__version__) + assert torch_version.major == 1 + + if torch_version.minor < 9: + cmd = [ + sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}", + f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}", + f"--node_rank={node_rank}" + ] + else: + # extra launch args for torch distributed launcher with torch >= 1.9 + default_torchrun_rdzv_args = dict(rdzv_backend="c10d", + rdzv_endpoint=f"{master_addr}:{master_port}", + rdzv_id="colossalai-default-job") + + # update rdzv arguments + for key in default_torchrun_rdzv_args.keys(): + if key in extra_launch_args: + value = extra_launch_args.pop(key) + default_torchrun_rdzv_args[key] = value + + if torch_version.minor < 10: + cmd = [ + sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}", + f"--nnodes={num_nodes}", f"--node_rank={node_rank}" + ] + else: + cmd = [ + "torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}" + ] + cmd += _arg_dict_to_list(default_torchrun_rdzv_args) + + cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args + cmd = ' '.join(cmd) + return cmd + + +def launch_multi_processes(args: Config) -> None: + """ + Launch multiple processes on a single node or multiple nodes. + + The overall logic can be summarized as the pseudo code below: + + if hostfile given: + hostinfo = parse_hostfile(hostfile) + hostinfo = include_or_exclude_hosts(hostinfo) + launch_on_multi_nodes(hostinfo) + elif hosts given: + hostinfo = parse_hosts(hosts) + launch_on_multi_nodes(hostinfo) + else: + launch_on_current_node() + + Args: + args (Config): the arguments taken from command line + + """ + assert isinstance(args, Config) + + if args.nproc_per_node is None: + click.echo("--nproc_per_node did not receive any value") + exit() + + # cannot accept hosts and hostfile at the same time + if args.host and args.hostfile: + click.echo("Error: hostfile and hosts are mutually exclusive, only one is required") + + # check if hostfile is given + if args.hostfile: + device_pool = fetch_hostfile(args.hostfile, ssh_port=args.ssh_port) + active_device_pool = parse_device_filter(device_pool, args.include, args.exclude) + + if args.num_nodes > 0: + # only keep the first num_nodes to execute jobs + updated_active_device_pool = HostInfoList() + for count, hostinfo in enumerate(active_device_pool): + if args.num_nodes == count: + break + updated_active_device_pool.append(hostinfo) + active_device_pool = updated_active_device_pool + else: + active_device_pool = None + + env = os.environ.copy() + + # use hosts if hostfile is not given + if args.host and active_device_pool is None: + active_device_pool = HostInfoList() + host_list = args.host.strip().split(NODE_SEP) + for hostname in host_list: + hostinfo = HostInfo(hostname=hostname, port=args.ssh_port) + active_device_pool.append(hostinfo) + + if not active_device_pool: + # run on local node if not hosts or hostfile is given + # add local node to host info list + active_device_pool = HostInfoList() + localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port) + active_device_pool.append(localhost_info) + + # launch distributed processes + runner = MultiNodeRunner() + curr_path = os.path.abspath('.') + + # collect current path env + env = dict() + for k, v in os.environ.items(): + # do not support multi-line env var + if v and '\n' not in v: + env[k] = v + + # establish remote connection + runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env) + + # execute distributed launching command + for node_id, hostinfo in enumerate(active_device_pool): + cmd = get_launch_command(master_addr=args.master_addr, + master_port=args.master_port, + nproc_per_node=args.nproc_per_node, + user_script=args.user_script, + user_args=args.user_args, + node_rank=node_id, + num_nodes=len(active_device_pool), + extra_launch_args=args.extra_launch_args) + runner.send(hostinfo=hostinfo, cmd=cmd) + + runner.recv_from_all() + runner.stop_all() + runner.recv_from_all() diff --git a/colossalai/communication/__init__.py b/colossalai/communication/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..220481b7af15bcada443ed7c9f8c91350a5f76b1 --- /dev/null +++ b/colossalai/communication/__init__.py @@ -0,0 +1,26 @@ +from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce +from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, send_backward, + send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward, + recv_forward, recv_backward) +from .ring import ring_forward +from .utils import send_obj_meta, recv_obj_meta + +__all__ = [ + 'all_gather', + 'reduce_scatter', + 'all_reduce', + 'broadcast', + 'reduce', + 'send_forward', + 'send_forward_recv_forward', + 'send_forward_backward_recv_forward_backward', + 'send_backward', + 'send_backward_recv_backward', + 'send_backward_recv_forward', + 'send_forward_recv_backward', + 'recv_backward', + 'recv_forward', + 'ring_forward', + 'send_obj_meta', + 'recv_obj_meta', +] diff --git a/colossalai/communication/collective.py b/colossalai/communication/collective.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9e9927c7d9d6ecf83949c7f1d7c06bde57aaa9 --- /dev/null +++ b/colossalai/communication/collective.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +import torch.distributed as dist +from torch.distributed import ReduceOp +from torch import Tensor + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + + +def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor: + r"""Gathers all tensors from the parallel group and concatenates them in a + specific dimension. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be gathered. + dim (int): The dimension concatenating in. + parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + async_op (bool, optional): Whether operations are asynchronous. + + Returns: + Union[tuple(:class:`torch.Tensor`, work handle), :class:`torch.Tensor`]: The result of all-together only, + if async_op is set to False. A tuple of output of all-gather and Async work handle, if async_op is set to True. + """ + depth = gpc.get_world_size(parallel_mode) + if depth == 1: + out = tensor + work = None + else: + shape = list(tensor.shape) + shape[0], shape[dim] = shape[dim], shape[0] + shape[0] *= depth + out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device) + temp = list(torch.chunk(out, depth, dim=0)) + group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode) + work = dist.all_gather(tensor_list=temp, + tensor=tensor.transpose(0, dim).contiguous(), + group=group, + async_op=async_op) + out = torch.transpose(out, 0, dim) + if async_op: + return out, work + else: + return out + + +def reduce_scatter(tensor: Tensor, + dim: int, + parallel_mode: ParallelMode, + op: ReduceOp = ReduceOp.SUM, + async_op: bool = False) -> Tensor: + r"""Reduces all tensors then scatters it in a specific dimension to all + members in the parallel group. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be reduce_scattered. + dim (int): The dimension concatenating in. + parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + op (torch.distributed.ReduceOp, optional): The type of reduce operation, + should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR]. + More details about ReduceOp please refer to + `ReduceOp `_. + async_op (bool, optional): Whether operations are asynchronous. + + Returns: + Union[tuple(:class:`torch.Tensor`, work handle), :class:`torch.Tensor`]: The result of reduce_scatter only, + if async_op is set to False. A tuple of output of all-gather and Async work handle, if async_op is set to True. + """ + depth = gpc.get_world_size(parallel_mode) + if depth == 1: + out = tensor + work = None + else: + temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim))) + out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device) + group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode) + work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op) + if async_op: + return out, work + else: + return out + + +def all_reduce(tensor: Tensor, + parallel_mode: ParallelMode, + op: ReduceOp = ReduceOp.SUM, + async_op: bool = False) -> Tensor: + r"""Reduces the tensor data across whole parallel group in such a way that all get the final result. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be all-reduced. + parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + op (torch.distributed.ReduceOp, optional): The type of reduce operation, + should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR]. + More details about ReduceOp please refer to + `ReduceOp `_. + async_op (bool, optional): Whether operations are asynchronous. + + Returns: + Union[tuple(:class:`torch.Tensor`, work handle), :class:`torch.Tensor`]: The result of all-gather only, + if async_op is set to False. A tuple of output of all-gather and Async work handle, if async_op is set to True. + """ + depth = gpc.get_world_size(parallel_mode) + if depth == 1: + out = tensor + work = None + else: + out = tensor.contiguous() + group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode) + work = dist.all_reduce(out, op=op, group=group, async_op=async_op) + if async_op: + return out, work + else: + return out + + +def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False): + r"""Broadcast tensors to whole parallel group. Tensor must have the same + number of elements in all processes participating in the collective. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be broadcast. + src (int): Source rank. + parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + async_op (bool, optional): Whether operations are asynchronous. + + Returns: + Union[tuple(:class:`torch.Tensor`, work handle), :class:`torch.Tensor`]: The tensor need to be broadcast only, + if async_op is set to False. A tuple of output of all-gather and Async work handle, if async_op is set to True. + """ + depth = gpc.get_world_size(parallel_mode) + if depth == 1: + out = tensor + work = None + else: + out = tensor.contiguous() + group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode) + work = dist.broadcast(out, src=src, group=group, async_op=async_op) + if async_op: + return out, work + else: + return out + + +def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False): + r"""Reduce tensors across whole parallel group. Only the process with + rank ``dst`` is going to receive the final result. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be reduced. + dst (int): Destination rank. + parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + async_op (bool, optional): Whether operations are asynchronous. + + Returns: + Union[tuple(:class:`torch.Tensor`, work handle), :class:`torch.Tensor`]: The result of reduce only, + if async_op is set to False. A tuple of output of all-gather and Async work handle, if async_op is set to True. + """ + depth = gpc.get_world_size(parallel_mode) + if depth == 1: + out = tensor + work = None + else: + out = tensor.contiguous() + group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode) + work = dist.reduce(out, dst=dst, op=op, group=group, async_op=async_op) + if async_op: + return out, work + else: + return out + + +def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None: + r"""Modified from `torch.distributed.scatter_object_list ` to fix issues + """ + if dist.distributed_c10d._rank_not_in_group(group): + return + + if (not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1): + raise RuntimeError("Expected argument scatter_object_output_list to be a list of size at least 1.") + + # set tensor device to cuda if backend is nccl + device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu") + + my_rank = dist.get_rank() # use global rank + if my_rank == src: + tensor_list, tensor_sizes = zip( + *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list]) + tensor_list = list(map(lambda x: x.to(device), tensor_list)) + tensor_sizes = list(map(lambda x: x.to(device), tensor_sizes)) + + # Src rank broadcasts the maximum tensor size. This is because all ranks are + # expected to call into scatter() with equal-sized tensors. + if my_rank == src: + max_tensor_size = max(tensor_sizes) + for tensor in tensor_list: + tensor.resize_(max_tensor_size) + else: + max_tensor_size = torch.tensor([0], dtype=torch.long).to(device) + + dist.broadcast(max_tensor_size, src=src, group=group) + + # Scatter actual serialized objects + output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8).to(device) + dist.scatter( + output_tensor, + scatter_list=None if my_rank != src else tensor_list, + src=src, + group=group, + ) + + # Scatter per-object sizes to trim tensors when deserializing back to object + obj_tensor_size = torch.tensor([0], dtype=torch.long).to(device) + dist.scatter( + obj_tensor_size, + scatter_list=None if my_rank != src else tensor_sizes, + src=src, + group=group, + ) + + output_tensor, obj_tensor_size = output_tensor.cpu(), obj_tensor_size.cpu() + # Deserialize back to object + scatter_object_output_list[0] = dist.distributed_c10d._tensor_to_object(output_tensor, obj_tensor_size) diff --git a/colossalai/communication/p2p.py b/colossalai/communication/p2p.py new file mode 100644 index 0000000000000000000000000000000000000000..6dd4d0d6608df586bd809699606a5fa6918e6066 --- /dev/null +++ b/colossalai/communication/p2p.py @@ -0,0 +1,405 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import List, Tuple, Union +import torch +import torch.distributed as dist + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device +from functools import reduce +import operator +from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor + +TensorShape = Union[torch.Size, List[int], Tuple[int]] + + +def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]: + """get the exact tensor shape when communicating and return whether the tensor is a chunk + + Args: + tensor_shape (:class:`torch.Size`): shape of tensor + chunk_tensor (bool, optional): whether to chunk tensor, defaults to False + + Returns: + Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor + """ + if chunk_tensor: + tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) + tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR) + if tensor_chunk_shape % tensor_parallel_world_size == 0: + tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size + else: + tensor_chunk_shape = tensor_shape + chunk_tensor = False + else: + tensor_chunk_shape = tensor_shape + return tensor_chunk_shape, chunk_tensor + + +def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors): + if isinstance(recv_shapes, torch.Size): + recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors) + buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype) + return buffer_recv, recv_split + buffer_recv = [] + for recv_shape in recv_shapes: + recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors) + tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype) + buffer_recv.append(tensor_recv) + return buffer_recv, recv_split + + +def process_object_to_send(object_send, scatter_gather_tensors): + if isinstance(object_send, torch.Tensor): + send_split = _get_tensor_shape(object_send.shape, scatter_gather_tensors)[1] + if send_split: + object_send = split_tensor_into_1d_equal_chunks(object_send) + return object_send + + object_send_list = [] + for tensor_send in object_send: + send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1] + if send_split: + object_send_list.append(split_tensor_into_1d_equal_chunks(tensor_send)) + else: + object_send_list.append(tensor_send) + object_send = tuple(object_send_list) + + return object_send + + +def filling_ops_queue(obj, comm_op, comm_rank, ops_queue): + if isinstance(obj, torch.Tensor): + op_to_add = dist.P2POp(comm_op, obj, comm_rank) + ops_queue.append(op_to_add) + else: + for tensor_to_comm in obj: + op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank) + ops_queue.append(op_to_add) + + +def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None, + object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None, + recv_prev: bool = False, + recv_next: bool = False, + recv_prev_shape: Union[torch.Size, List[torch.Size]] = None, + recv_next_shape: Union[torch.Size, List[torch.Size]] = None, + prev_rank: int = None, + next_rank: int = None, + dtype: torch.dtype = None, + scatter_gather_tensors: bool = False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: + """ + Adapted from megatron.p2p_communication. + Communicate tensors between stages. Used as helper method in other + communication methods that are used in pipeline schedule. + Takes the following arguments: + object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank (no tensor sent if + set to None). + object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank (no tensor sent if + set to None). + recv_prev (bool): boolean for whether tensor should be received from + previous rank. + recv_next (bool): boolean for whether tensor should be received from + next rank. + recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the previous stage, defualts to None. + recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the next stage, defualts to None. + prev_rank (int): the rank of the previous pipeline stage, defualts to None, + next_rank (int): the rank of the next pipeline stage, defualts to None, + dtype (torch.dtype): data type of intermediate buffers, defaults to None + scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False + + Returns: + Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next + """ + + # Create placeholder tensors for receive in forward and backward directions + # if needed. + tensor_recv_prev = None + tensor_recv_next = None + + if recv_prev: + assert recv_prev_shape is not None + tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(recv_prev_shape, dtype, + scatter_gather_tensors) + + if recv_next: + assert recv_next_shape is not None + tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(recv_next_shape, dtype, + scatter_gather_tensors) + + if object_send_prev is not None or recv_prev: + if prev_rank is None: + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + + if object_send_next is not None or recv_next: + if next_rank is None: + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + + if object_send_prev is not None: + object_send_prev = process_object_to_send(object_send_prev, scatter_gather_tensors) + + if object_send_next is not None: + object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors) + + ops = [] + if object_send_prev is not None: + filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops) + + if tensor_recv_prev is not None: + filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops) + + if tensor_recv_next is not None: + filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops) + + if object_send_next is not None: + filling_ops_queue(object_send_next, dist.isend, next_rank, ops) + + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). + torch.cuda.synchronize() + + if recv_prev and recv_prev_split: + if isinstance(tensor_recv_prev, torch.Tensor): + tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_() + else: + for index in range(len(tensor_recv_prev)): + tensor_recv_prev[index] = gather_split_1d_tensor(tensor_recv_prev[index]).view( + recv_prev_shape[index]).requires_grad_() + + if recv_next and recv_next_split: + if isinstance(tensor_recv_next, torch.Tensor): + tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_() + else: + for index in range(len(tensor_recv_next)): + tensor_recv_next[index] = gather_split_1d_tensor(tensor_recv_next[index]).view( + recv_next_shape[index]).requires_grad_() + + return tensor_recv_prev, tensor_recv_next + + +def recv_forward(input_tensor_shape, + prev_rank=None, + dtype=torch.float, + scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + + Args: + input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list. + """ + if gpc.is_pipeline_first_stage(): + input_tensor = None + else: + input_tensor, _ = _communicate(recv_prev=True, + recv_prev_shape=input_tensor_shape, + prev_rank=prev_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors) + return input_tensor + + +def recv_backward(output_grad_shape, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + + Args: + output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradident tensor list. + """ + if gpc.is_pipeline_last_stage(): + output_tensor_grad = None + else: + _, output_tensor_grad = _communicate(recv_next=True, + recv_next_shape=output_grad_shape, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors) + return output_tensor_grad + + +def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None: + """Sends the input tensor to the next stage in pipeline. + + Args: + output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not gpc.is_pipeline_last_stage(): + _communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors) + + +def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + + Args: + input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if not gpc.is_pipeline_first_stage(): + _communicate(object_send_prev=input_tensor_grad, + prev_rank=prev_rank, + scatter_gather_tensors=scatter_gather_tensors) + + +def send_forward_recv_backward(output_tensor, + output_grad_shape, + recv_next=True, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: + """Batched communication operation. Sends the input tensor to the + next stage in pipeline, while receives the gradient tensor from the + next stage in pipeline as the input gradient tensor of this stage. + + Args: + output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. + output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor. + """ + if gpc.is_pipeline_last_stage(): + output_tensor_grad = None + else: + _, output_tensor_grad = _communicate(object_send_next=output_tensor, + recv_next=recv_next, + recv_next_shape=output_grad_shape, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors) + return output_tensor_grad + + +def send_backward_recv_forward(input_tensor_grad, + input_tensor_shape, + recv_prev=True, + prev_rank=None, + dtype=torch.float, + scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: + """Batched communication operation. Sends the gradient tensor to the + previous stage in pipeline, while receives the output tensor from the + previous stage in pipeline as the input of this stage. + + Args: + input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. + input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor. + """ + if gpc.is_pipeline_first_stage(): + input_tensor = None + else: + input_tensor, _ = _communicate(object_send_prev=input_tensor_grad, + recv_prev=recv_prev, + recv_prev_shape=input_tensor_shape, + prev_rank=prev_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors) + return input_tensor + + +def send_forward_recv_forward(output_tensor, + input_tensor_shape, + recv_prev=True, + prev_rank=None, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: + """Batched communication operation. Sends the input tensor to the + next stage in pipeline, while receives the output tensor from the + previous stage in pipeline as the input of this stage. + + Args: + output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. + input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor. + """ + input_tensor, _ = _communicate(object_send_next=output_tensor, + recv_prev=recv_prev, + recv_prev_shape=input_tensor_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors) + return input_tensor + + +def send_backward_recv_backward(input_tensor_grad, + output_grad_shape, + recv_next=True, + prev_rank=None, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: + """Batched communication operation. Sends the gradient tensor to the + previous stage in pipeline, while receives the gradient tensor from the + next member in pipeline as the input of this stage. + + Args: + input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. + output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor. + """ + _, output_tensor_grad = _communicate(object_send_prev=input_tensor_grad, + recv_next=recv_next, + recv_next_shape=output_grad_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors) + return output_tensor_grad + + +def send_forward_backward_recv_forward_backward( + output_tensor, + input_tensor_grad, + input_tensor_shape, + output_grad_shape, + recv_prev=True, + recv_next=True, + prev_rank=None, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: + """Batched communication operation. Sends the input tensor to the next stage in pipeline and + the gradient tensor to the previous stage, while receives the input gradient tensor from the + next stage and the input tensor from the previous stage. + + Args: + output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the next. + input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the previous. + input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the previous. + output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the next. + + Returns: + Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor) + """ + input_tensor, output_tensor_grad = _communicate(object_send_next=output_tensor, + object_send_prev=input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + recv_prev_shape=input_tensor_shape, + recv_next_shape=output_grad_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors) + return input_tensor, output_tensor_grad diff --git a/colossalai/communication/p2p_v2.py b/colossalai/communication/p2p_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..0b575e7dba77c02afe351a399dcb567701f569d8 --- /dev/null +++ b/colossalai/communication/p2p_v2.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import List, Tuple, Union, Any +import pickle +import io + +import torch +import torch.distributed as dist +from torch.distributed import distributed_c10d as c10d +from torch.distributed import ProcessGroupNCCL + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc + +TensorShape = Union[torch.Size, List[int], Tuple[int]] +_pg_manager = {} +_unpickler = pickle.Unpickler + + +def init_process_group(): + """intialise process group by dist.new_group in the adjacent stages + + Args: + None + + Returns: + None + """ + world_size = gpc.get_world_size(ParallelMode.PIPELINE) + for i in range(world_size - 1): + _pg_manager[(i, i + 1)] = dist.new_group([i, i + 1]) + + +def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGroupNCCL: + """get the group handle of two given ranks + + Args: + first_rank (int): first rank in the pair + second_rank (int): second rank in the pair + + Returns: + :class:`ProcessGroupNCCL`: the handle of the group consisting of the given two ranks + """ + if len(_pg_manager) == 0: + init_process_group() + if first_rank > second_rank: + first_rank, second_rank = second_rank, first_rank + pair_key = (first_rank, second_rank) + return _pg_manager[pair_key] + + +def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object: + """transform tensor to object with unpickle. + Info of the device in bytes stream will be modified into current device before unpickling + + Args: + tensor (:class:`torch.tensor`): tensor to be unpickled + tensor_size (:class:`torch.Size`): Size of the real info in bytes + + Returns: + Any: object after unpickled + """ + buf = tensor.numpy().tobytes()[:tensor_size] + if b'cuda' in buf: + buf_array = bytearray(buf) + device_index = torch.cuda.current_device() + buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index + buf = bytes(buf_array) + + io_bytes = io.BytesIO(buf) + byte_pickler = _unpickler(io_bytes) + unpickle = byte_pickler.load() + + return unpickle + + +def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=None): + """This is a modified version of the broadcast_object_list in torch.distribution + The only difference is that object will be move to correct device after unpickled. + If local_rank = src, then object list will be sent to rank src. Otherwise, object list will + be updated with data sent from rank src. + + Args: + object_list (List[Any]): list of object to broadcast + src (int): source rank to broadcast + dst (int): dst rank to broadcast + device (:class:`torch.device`): device to do broadcast. current device in default + + """ + group = _acquire_pair_group_handle(src, dst) + + if c10d._rank_not_in_group(group): + c10d._warn_not_in_group("broadcast_object_list") + return + + local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + # Serialize object_list elements to tensors on src rank. + if local_rank == src: + tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + else: + object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long) + + is_nccl_backend = c10d._check_for_nccl_backend(group) + current_device = None + + if device is not None: + if is_nccl_backend and device.type != "cuda": + raise ValueError("device type must be cuda for nccl backend") + current_device = device + else: + current_device = torch.device("cpu") + if is_nccl_backend: + current_device = torch.device("cuda", torch.cuda.current_device()) + if is_nccl_backend: + object_sizes_tensor = object_sizes_tensor.to(current_device) + + # Broadcast object sizes + c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False) + + # Concatenate and broadcast serialized object tensors + if local_rank == src: + object_tensor = torch.cat(tensor_list) + else: + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + dtype=torch.uint8, + ) + + if is_nccl_backend: + object_tensor = object_tensor.to(current_device) + + c10d.broadcast(object_tensor, src=src, group=group, async_op=False) + + # Deserialize objects using their stored sizes. + offset = 0 + + if local_rank != src: + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset:offset + obj_size] + obj_view = obj_view.type(torch.uint8) + if obj_view.device != torch.device("cpu"): + obj_view = obj_view.cpu() + offset += obj_size + # unpickle + unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size) + + # unconsistence in device + if isinstance(unpickle_object, + torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device(): + unpickle_object = unpickle_object.cuda() + + object_list[i] = unpickle_object + + +def _send_object(object: Any, dst: int) -> None: + """send anything to dst rank + Args: + object (Any): object needed to be sent + dst (int): rank of the destination + + Returns: + None + """ + local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + # handler = _acquire_pair_group_handle(local_rank, dst) + + # transform to list if not + if isinstance(object, torch.Tensor): + object = [object] + + # broadcast length first + # TODO : more elegant ? P.S. reduce a _broadcast_object_list + _broadcast_object_list([len(object)], local_rank, dst) + # then broadcast safely + _broadcast_object_list(object, local_rank, dst) + + +def _recv_object(src: int) -> Any: + """recv anything from src + + Args: + src (int): source rank of data. local rank will receive data from src rank. + + Returns: + Any: Object received from src. + """ + local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + # handler = _acquire_pair_group_handle(local_rank, src) + # recv length first + length = [0] + _broadcast_object_list(length, src, local_rank) + + # then create recv buff from length[0] and broadcast + object = [None] * length[0] + _broadcast_object_list(object, src, local_rank) + + if length[0] == 1: + object = object[0] + + return object + + +def recv_forward(prev_rank: int = None) -> Any: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + + Args: + input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + """ + if gpc.is_pipeline_first_stage(): + input_tensor = None + else: + if prev_rank is None: + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + input_tensor = _recv_object(prev_rank) + + return input_tensor + + +def recv_backward(next_rank: int = None) -> Any: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + + Args: + output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradident tensor list. + """ + if gpc.is_pipeline_last_stage(): + output_tensor_grad = None + else: + if next_rank is None: + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + output_tensor_grad = _recv_object(next_rank) + + return output_tensor_grad + + +def send_forward(output_object: Any, next_rank: int = None) -> None: + """Sends the input tensor to the next stage in pipeline. + + Args: + output_object Any: Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not gpc.is_pipeline_last_stage(): + if next_rank is None: + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + _send_object(output_object, next_rank) + + +def send_backward(input_object: Any, prev_rank: int = None) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + + Args: + input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if not gpc.is_pipeline_first_stage(): + if prev_rank is None: + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + _send_object(input_object, prev_rank) diff --git a/colossalai/communication/ring.py b/colossalai/communication/ring.py new file mode 100644 index 0000000000000000000000000000000000000000..aece7574b7c41cac3b16cd5891b1e26d0ede9c36 --- /dev/null +++ b/colossalai/communication/ring.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device, synchronize + + +def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor: + """Sends a tensor to the next member and receives a tensor from the previous member. + This function returns the received tensor from the previous member. + + Args: + tensor_send_next (:class:`torch.Tensor`): Tensor sent to next member + parallel_mode (ParallelMode): Parallel group mode used in this communication + + Returns: + :class:`torch.Tensor`: The tensor received from the previous. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + buffer_shape = tensor_send_next.size() + + ops = [] + current_rank = gpc.get_global_rank() + + tensor_recv_prev = torch.empty(buffer_shape, + requires_grad=True, + device=get_current_device(), + dtype=tensor_send_next.dtype) + + # send to next rank + send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next, + gpc.get_next_global_rank(parallel_mode)) + ops.append(send_next_op) + + # receive from prev rank + recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev, + gpc.get_prev_global_rank(parallel_mode)) + ops.append(recv_prev_op) + + if current_rank % 2 == 0: + ops = ops[::-1] + + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # To protect against race condition when using batch_isend_irecv(). + synchronize() + + return tensor_recv_prev diff --git a/colossalai/communication/utils.py b/colossalai/communication/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ef9eceea847dd3d6cb036e87e369529dcbe0db41 --- /dev/null +++ b/colossalai/communication/utils.py @@ -0,0 +1,126 @@ +import torch +import torch.distributed as dist + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device +from typing import Union, List, Tuple + +TensorShape = Union[torch.Size, List[int], Tuple[int]] + + +def send_meta_helper(obj, next_rank, tensor_kwargs): + send_shape = torch.tensor(obj.size(), **tensor_kwargs) + send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs) + dist.send(send_ndims, next_rank) + dist.send(send_shape, next_rank) + + +def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool: + """Sends obj meta information before sending a specific obj. + Since the recipient must know the shape of the obj in p2p communications, + meta information of the obj should be sent before communications. This function + synchronizes with :func:`recv_obj_meta`. + + Args: + obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent. + need_meta (bool, optional): If False, meta information won't be sent. + next_rank (int): The rank of the next member in pipeline parallel group. + + Returns: + bool: False + """ + if need_meta: + if next_rank is None: + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + + tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} + if isinstance(obj, torch.Tensor): + send_obj_nums = torch.tensor(1, **tensor_kwargs) + dist.send(send_obj_nums, next_rank) + send_meta_helper(obj, next_rank, tensor_kwargs) + else: + send_obj_nums = torch.tensor(len(obj), **tensor_kwargs) + dist.send(send_obj_nums, next_rank) + for tensor_to_send in obj: + send_meta_helper(tensor_to_send, next_rank, tensor_kwargs) + + return False + + +def recv_meta_helper(prev_rank, tensor_kwargs): + recv_ndims = torch.empty((), **tensor_kwargs) + dist.recv(recv_ndims, prev_rank) + recv_shape = torch.empty(recv_ndims, **tensor_kwargs) + dist.recv(recv_shape, prev_rank) + return recv_shape + + +def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size: + """Receives obj meta information before receiving a specific obj. + Since the recipient must know the shape of the obj in p2p communications, + meta information of the obj should be received before communications. This function + synchronizes with :func:`send_obj_meta`. + + Args: + obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received. + prev_rank (int): The rank of the source of the obj. + + Returns: + Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received. + """ + if obj_shape is None: + if prev_rank is None: + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + + tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} + recv_obj_nums = torch.empty((), **tensor_kwargs) + dist.recv(recv_obj_nums, prev_rank) + if recv_obj_nums.item() == 1: + recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) + obj_shape = torch.Size(recv_shape) + else: + obj_shape = [] + for i in range(recv_obj_nums.item()): + recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) + obj_shape.append(torch.Size(recv_shape)) + + return obj_shape + + +def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor: + """Break a tensor into equal 1D chunks. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be split before communication. + new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor. + + Returns: + :class:`torch.Tensor`: The split tensor + """ + partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D) + start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D) + end_index = start_index + partition_size + if new_buffer: + data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) + data.copy_(tensor.view(-1)[start_index:end_index]) + else: + data = tensor.view(-1)[start_index:end_index] + return data + + +def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor: + """Opposite of above function, gather values from model parallel ranks. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be gathered after communication. + Returns: + :class:`torch.Tensor`: The gathered tensor. + """ + world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + numel = torch.numel(tensor) + numel_gathered = world_size * numel + gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) + chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)] + dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D)) + return gathered diff --git a/colossalai/constants.py b/colossalai/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..6cf9085f9fbb63ea18d2712f99c08f24b539245d --- /dev/null +++ b/colossalai/constants.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence'] +TENSOR_PARALLEL_MODE = 'tensor_parallel_mode' + +# initializer +INITIALIZER_MAPPING = { + 'data': 'Initializer_Data', + 'tensor': 'Initializer_Tensor', + 'pipeline': 'Initializer_Pipeline', + 'embedding': 'Initializer_Embedding', + '1d': 'Initializer_1D', + '2d': 'Initializer_2D', + '2.5d': 'Initializer_2p5D', + '3d': 'Initializer_3D', + 'sequence': 'Initializer_Sequence', + 'model': 'Initializer_Model', + 'moe': 'Initializer_Moe' +} + +# 3D parallelism groups +INPUT_GROUP_3D = 'input_group_3d' +WEIGHT_GROUP_3D = 'weight_group_3d' +OUTPUT_GROUP_3D = 'output_group_3d' +INPUT_X_WEIGHT_3D = 'input_x_weight_group_3d' +OUTPUT_X_WEIGHT_3D = 'output_x_weight_group_3d' + +# Attributes of tensor parallel parameters +IS_TENSOR_PARALLEL = 'is_tensor_parallel' +NUM_PARTITIONS = 'num_partitions' +TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS] diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..50178b5fa850777f8455798cc6ab9d7254c5a9fe --- /dev/null +++ b/colossalai/context/__init__.py @@ -0,0 +1,6 @@ +from .config import Config, ConfigException +from .parallel_context import ParallelContext +from .parallel_mode import ParallelMode +from .moe_context import MOE_CONTEXT +from .process_group_initializer import * +from .random import * diff --git a/colossalai/context/config.py b/colossalai/context/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8903707708df96eac7a0a70343e37e984e6fabed --- /dev/null +++ b/colossalai/context/config.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import inspect +import sys +from importlib.machinery import SourceFileLoader +from pathlib import Path +from colossalai.logging import get_dist_logger + + +class Config(dict): + """This is a wrapper class for dict objects so that values of which can be + accessed as attributes. + + Args: + config (dict): The dict object to be wrapped. + """ + + def __init__(self, config: dict = None): + if config is not None: + for k, v in config.items(): + self._add_item(k, v) + + def __missing__(self, key): + raise KeyError(key) + + def __getattr__(self, key): + try: + value = super(Config, self).__getitem__(key) + return value + except KeyError: + raise AttributeError(key) + + def __setattr__(self, key, value): + super(Config, self).__setitem__(key, value) + + def _add_item(self, key, value): + if isinstance(value, dict): + self.__setattr__(key, Config(value)) + else: + self.__setattr__(key, value) + + def update(self, config): + assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.' + for k, v in config.items(): + self._add_item(k, v) + return self + + @staticmethod + def from_file(filename: str): + """Reads a python file and constructs a corresponding :class:`Config` object. + + Args: + filename (str): Name of the file to construct the return object. + + Returns: + :class:`Config`: A :class:`Config` object constructed with information in the file. + + Raises: + AssertionError: Raises an AssertionError if the file does not exist, or the file is not .py file + """ + + # check config path + if isinstance(filename, str): + filepath = Path(filename).absolute() + elif isinstance(filename, Path): + filepath = filename.absolute() + + assert filepath.exists(), f'{filename} is not found, please check your configuration path' + + # check extension + extension = filepath.suffix + assert extension == '.py', 'only .py files are supported' + + # import the config as module + remove_path = False + if filepath.parent not in sys.path: + sys.path.insert(0, (filepath)) + remove_path = True + + module_name = filepath.stem + source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath)) + module = source_file.load_module() + + # load into config + config = Config() + + for k, v in module.__dict__.items(): + if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v): + continue + else: + config._add_item(k, v) + + logger = get_dist_logger() + logger.debug('variables which starts with __, is a module or class declaration are omitted in config file') + + # remove module + del sys.modules[module_name] + if remove_path: + sys.path.pop(0) + + return config + + +class ConfigException(Exception): + pass diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py new file mode 100644 index 0000000000000000000000000000000000000000..66f28e156a2885d3f4fa1d2d4bbdb040d6d15fa5 --- /dev/null +++ b/colossalai/context/moe_context.py @@ -0,0 +1,129 @@ +import torch +import torch.distributed as dist + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.context.singleton_meta import SingletonMeta +from colossalai.tensor import ProcessGroup + +from typing import Tuple + + +def _check_sanity(): + from colossalai.core import global_context as gpc + if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1: + raise NotImplementedError("Moe is not compatible with tensor or " + "pipeline parallel at present.") + + +class MoeParallelInfo: + """Moe parallelism information, storing parallel sizes and groups. + """ + + def __init__(self, ep_size: int, dp_size: int): + _check_sanity() + self.ep_size = ep_size + self.dp_size = dp_size + self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size) + self.ep_group = self.pg.tp_process_group() + self.dp_group = self.pg.dp_process_group() + + +class MoeContext(metaclass=SingletonMeta): + """MoE parallel context manager. This class manages different + parallel groups in MoE context and MoE loss in training. + """ + + def __init__(self): + self.world_size = 1 + # Users may want to set maximum expert parallel size smaller than the world size + # since very low bandwidth across nodes may constrain the performance of MoE + # When we have a maximum expert parallel size, we have a minimum data parallel size naturally + self.max_ep_size = 1 + self.min_dp_size = 1 + self.aux_loss = None + self.use_kernel_optim = True + + self.has_setup = False + self._parallel_info_dict = dict() + + @property + def parallel_info_dict(self): + return self._parallel_info_dict + + @property + def is_initialized(self): + return self.has_setup + + def setup(self, seed: int, use_kernel_optim: bool = True): + assert not self.is_initialized, "MoE distributed context shouldn't be set up again" + _check_sanity() + assert torch.cuda.is_available(), "MoE requires to enable CUDA first" + + self.world_size = dist.get_world_size() + + from colossalai.core import global_context as gpc + self.max_ep_size = gpc.config.get('max_ep_size', self.world_size) + assert self.world_size % self.max_ep_size == 0, \ + "Maximum epxert parallel size must be a factor of the number of GPUs" + self.min_dp_size = self.world_size // self.max_ep_size + + # Enabling kernel optimization may raise error in some cases + # Users can close kernel optimization manually + self.use_kernel_optim = use_kernel_optim + + from .random import moe_set_seed + moe_set_seed(seed) + self.has_setup = True + + def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]: + """Calculate the Data Parallel Group and Expert Parallel Group. + + Parameters + ---------- + num_experts : int + The number experts + + Returns + ------- + int, MoeParallelInfo + number of local experts, the MoeParallelInfo of the current ep_size + """ + + gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater + lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less + + assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \ + " is not a multiple of ep size or vice versa." + + # If the number of experts is greater than maximum expert parallel size. a.k.a ep_size, + # there are multiple experts in each GPU and each GPU has different experts + # So it's data parallel size is 1 + # Otherwise, there is only one expert in each GPU + # The data parallel size should be calculated + dp_size = 1 if gt_flag else self.max_ep_size // num_experts + ep_size = self.max_ep_size // dp_size + + # Calculate the number of experts for each GPU + num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size + + # Don't forget to multiply minimum data parallel size + dp_size *= self.min_dp_size + if not (ep_size in self.parallel_info_dict): + self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size) + + return num_local_experts, self.parallel_info_dict[ep_size] + + def set_kernel_not_use(self): + self.use_kernel_optim = False + + def reset_loss(self): + self.aux_loss = 0 + + def add_loss(self, loss): + self.aux_loss += loss + + def get_loss(self): + return self.aux_loss + + +MOE_CONTEXT = MoeContext() diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py new file mode 100644 index 0000000000000000000000000000000000000000..afa306065abe393520541f7804fe22223816b158 --- /dev/null +++ b/colossalai/context/parallel_context.py @@ -0,0 +1,577 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import random +import socket +from collections import Counter +from threading import local +from typing import Union + +import numpy as np +import torch +import torch.distributed as dist +from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING +from colossalai.context.config import Config +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.logging import get_dist_logger +from colossalai.registry import DIST_GROUP_INITIALIZER + +from .parallel_mode import ParallelMode +from .random import add_seed, get_seeds, set_mode +from colossalai.context.singleton_meta import SingletonMeta + + +class ParallelContext(metaclass=SingletonMeta): + """This class provides interface functions for users to get the parallel context, + such as the global rank, the local rank, the world size, etc. of each device. + + Note: + The parallel_mode used in this class should be concluded in ``ParallelMode``. + More details about ``ParallelMode`` could be found in + `parallel_mode `_. + """ + + def __init__(self): + # distributed settings + self._global_ranks = dict() + self._local_ranks = dict() + self._world_sizes = dict() + self._groups = dict() + self._cpu_groups = dict() + self._ranks_in_group = dict() + + # load config from file + self._config = None + + # default 3D parallel args, will be overwritten during process group intialization + self.world_size = 1 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 1 + self.tensor_parallel_size = 1 + self.num_processes_on_current_node = -1 + self.virtual_pipeline_parallel_size = None + self.virtual_pipeline_parallel_rank = None + + # logging + self._verbose = False + self._logger = get_dist_logger() + + @property + def config(self): + return self._config + + @property + def verbose(self): + return self._verbose + + @verbose.setter + def verbose(self, verbose_: bool): + self._verbose = verbose_ + + def load_config(self, config: Union[dict, str]): + """Loads the configuration from either a dict or a file. + + Args: + config (dict or str): Either a dict containing the configuration information or the filename + of a file containing the configuration information. + + Raises: + TypeError: Raises a TypeError if `config` is neither a dict nor a str. + """ + if isinstance(config, str): + self._config = Config.from_file(config) + elif isinstance(config, dict): + self._config = Config(config) + else: + raise TypeError("Invalid type for config, only dictionary or string is supported") + + def detect_num_processes_on_current_node(self): + hostname = socket.gethostname() + hostname_list = [None for _ in range(self.get_world_size(ParallelMode.GLOBAL))] + dist.all_gather_object(hostname_list, hostname, group=self.get_group(ParallelMode.GLOBAL)) + counter = Counter(hostname_list) + self.num_processes_on_current_node = counter[hostname] + + @staticmethod + def _check_parallel_mode(parallel_mode: ParallelMode): + assert isinstance(parallel_mode, ParallelMode), \ + f'expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}' + + def get_global_rank(self): + """Returns the global rank of the current device. + + Returns: + int: The global rank of the current device + """ + return self._global_ranks[ParallelMode.GLOBAL] + + def add_global_rank(self, parallel_mode: ParallelMode, rank: int): + """Adds the global rank of the current device for `parallel_mode` to the context. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank. + rank (int): The rank to be added + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + """ + self._check_parallel_mode(parallel_mode) + self._global_ranks[parallel_mode] = rank + + def get_local_rank(self, parallel_mode: ParallelMode): + """Returns the local rank of the current device. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + int: The local rank of the current device for `parallel_mode`. + """ + self._check_parallel_mode(parallel_mode) + return self._local_ranks[parallel_mode] + + def _add_local_rank(self, parallel_mode: ParallelMode, rank: int): + """Adds the local rank of the current device for `parallel_mode` to the context. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank. + rank (int): The rank to be added. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + """ + self._check_parallel_mode(parallel_mode) + self._local_ranks[parallel_mode] = rank + + def get_next_global_rank(self, parallel_mode: ParallelMode): + """Returns the global rank of the next device. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + int: The global rank of the next device for `parallel_mode`. + """ + self._check_parallel_mode(parallel_mode) + + # get rank and world size + local_rank = self.get_local_rank(parallel_mode) + world_size = self.get_world_size(parallel_mode) + ranks_in_group = self.get_ranks_in_group(parallel_mode) + + return ranks_in_group[(local_rank + 1) % world_size] + + def get_prev_global_rank(self, parallel_mode: ParallelMode): + """Returns the global rank of the previous device. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + int: The global rank of the previous device for `parallel_mode`. + """ + self._check_parallel_mode(parallel_mode) + + # get rank and world size + local_rank = self.get_local_rank(parallel_mode) + world_size = self.get_world_size(parallel_mode) + ranks_in_group = self.get_ranks_in_group(parallel_mode) + + return ranks_in_group[(local_rank - 1) % world_size] + + def is_first_rank(self, parallel_mode: ParallelMode): + """Returns a boolean value indicating whether the current device is the first one + among its group for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + bool: a boolean value indicating whether the current device is the first one + among its group for `parallel_mode`. + """ + rank = self.get_local_rank(parallel_mode) + return rank == 0 + + def is_last_rank(self, parallel_mode: ParallelMode): + """Returns a boolean value indicating whether the current device is the last one + among its group for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + bool: a boolean value indicating whether the current device is the first one + among its group for `parallel_mode`. + """ + rank = self.get_local_rank(parallel_mode) + world_size = self.get_world_size(parallel_mode) + return rank == world_size - 1 + + def is_pipeline_first_stage(self, ignore_virtual=False): + if not ignore_virtual: + if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != 0: + return False + return self.is_first_rank(ParallelMode.PIPELINE) + + def is_pipeline_last_stage(self, ignore_virtual=False): + if not ignore_virtual: + if self.virtual_pipeline_parallel_size \ + is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1: + return False + return self.is_last_rank(ParallelMode.PIPELINE) + + def get_world_size(self, parallel_mode: ParallelMode): + """Returns the world size for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + int: The world size for `parallel_mode`. + """ + self._check_parallel_mode(parallel_mode) + return self._world_sizes[parallel_mode] + + def _add_world_size(self, parallel_mode: ParallelMode, world_size: int): + """Adds world size for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode correponding to the process group + world_size (int): The world size to be added + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + """ + self._check_parallel_mode(parallel_mode) + self._world_sizes[parallel_mode] = world_size + + def get_group(self, parallel_mode: ParallelMode): + """Returns the group of the current device for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`. + """ + self._check_parallel_mode(parallel_mode) + return self._groups[parallel_mode] + + def _add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup): + """Adds the group of the current device for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + group (torch.distributed.ProcessGroup): The group to be added + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + """ + self._check_parallel_mode(parallel_mode) + self._groups[parallel_mode] = group + + def get_cpu_group(self, parallel_mode: ParallelMode): + """Returns the Gloo group of the current device for `parallel_mode`. + + :param parallel_mode: The chosen parallel mode + :type parallel_mode: :class:`colossalai.context.ParallelMode` + :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode` + :return: The group of the current device for `parallel_mode` + :rtype: torch.distributed.ProcessGroup + """ + self._check_parallel_mode(parallel_mode) + return self._cpu_groups[parallel_mode] + + def _add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup): + """Adds the Gloo group of the current device for `parallel_mode`. + + :param parallel_mode: The chosen parallel mode + :type parallel_mode: :class:`colossalai.context.ParallelMode` + :param group: The group to be added + :type group: torch.distributed.ProcessGroup + :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode` + """ + self._check_parallel_mode(parallel_mode) + self._cpu_groups[parallel_mode] = group + + def get_ranks_in_group(self, parallel_mode: ParallelMode): + """Returns the rank of the current device for `parallel_mode` in the group. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + int: The rank of the current device for `parallel_mode` in the group. + """ + self._check_parallel_mode(parallel_mode) + return self._ranks_in_group[parallel_mode] + + def _add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list): + """Adds the ranks of the current device for `parallel_mode` in the group. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + ranks (list): List of ranks to be added + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + """ + self._check_parallel_mode(parallel_mode) + self._ranks_in_group[parallel_mode] = ranks + + def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int): + """Initializes the global distributed environment + + Args: + rank (int): rank for the default process group. + world_size (int): world size of the default process group. + backend (str): backend for ``torch.distributed`` + host (str): the master address for distributed training. + port (str): the master port for distributed training + """ + # initialize the default process group + init_method = f'tcp://{host}:{port}' + dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) + + # None will give the default global process group for pytorch dist operations + ranks = list(range(world_size)) + cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None + self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL) + self.add_global_rank(ParallelMode.GLOBAL, rank) + + def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode): + self._add_local_rank(mode, local_rank) + self._add_world_size(mode, world_size) + self._add_group(mode, process_group) + self._add_cpu_group(mode, cpu_group) + self._add_ranks_in_group(mode, ranks_in_group) + + def check_sanity(self): + """Checks sanity of the parallel context. + + Raises: + AssertionError: Raises an AssertionError if the world size does not equal to the product + of data parallel size, pipeline parallel size and tensor parallel size. + """ + dps = self.data_parallel_size + pps = self.pipeline_parallel_size + tps = self.tensor_parallel_size + ws = self.world_size + assert ws == dps * pps * \ + tps, f"Expected the world size {ws} to be equal to data" \ + f" parallel size ({dps}) * pipeline parallel size " \ + f"({pps}) * tensor parallel size ({tps})" + + def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str): + if key in config: + ele = config[key] + if isinstance(ele, int): + setattr(self, attr_name, ele) + elif isinstance(ele, dict): + setattr(self, attr_name, ele['size']) + else: + raise NotImplementedError( + f'{"Parallel configuration does not support this kind of argument, please use int or dict"}') + + def init_parallel_groups(self): + """Initializes the parallel groups. + + Raises: + AssertionError: Raises an AssertionError if the field parallel is not present in the config file. + """ + + # get rank and world size + rank = self.get_global_rank() + world_size = self.get_world_size(ParallelMode.GLOBAL) + self.world_size = world_size + + # set parallel size as attributes for global context + parallel_config = self.config.get('parallel', None) + if parallel_config is not None: + self._set_parallel_size_from_config(parallel_config, 'pipeline', 'pipeline_parallel_size') + self._set_parallel_size_from_config(parallel_config, 'tensor', 'tensor_parallel_size') + + # the user should not set the data parallel size manually + # instead, it should be calculated based on other parallel config + self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size) + + # get the tensor parallel mode and check + tensor_parallel_mode = None + if parallel_config is not None and 'tensor' in \ + parallel_config and 'mode' in parallel_config['tensor']: + tensor_parallel_mode = parallel_config['tensor']['mode'] + assert tensor_parallel_mode in ALLOWED_MODES, \ + f"mode in the parallel config must be set to one of {ALLOWED_MODES}" + env.mode = tensor_parallel_mode + + self.check_sanity() + + pg_init = [] + # LSG: init data parallel process group for compatibility with other parallel module such as zero + pg_init.append(dict(type=INITIALIZER_MAPPING['data'])) + + # LSG: init model parallel process group for compatibility with amp and clip grad + pg_init.append(dict(type=INITIALIZER_MAPPING['model'])) + + if self.pipeline_parallel_size > 1: + pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline'])) + pg_init.append(dict(type=INITIALIZER_MAPPING['tensor'])) + + # init specific tensor parallel group + if tensor_parallel_mode is not None: + tensor_parallel_cfg = parallel_config['tensor'].copy() + + # remove duplicate parameters + tensor_parallel_cfg.pop('mode') + tensor_parallel_cfg.pop('size') + + # add this config to initialize later + pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg)) + + # run initialization of different process groups + for initializer_cfg in pg_init: + cfg = initializer_cfg.copy() + initializer_type = cfg.pop('type') + initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config, + self.data_parallel_size, + self.pipeline_parallel_size, + self.tensor_parallel_size, **cfg) + parallel_setting = initializer.init_dist_group() + if isinstance(parallel_setting, list): + for args in parallel_setting: + self._register_dist(*args) + else: + self._register_dist(*parallel_setting) + + def is_initialized(self, parallel_mode: ParallelMode): + """Returns a boolean value indicating whether `parallel_mode` is initialized + in the current system. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Returns: + bool: a boolean value indicating whether `parallel_mode` is initialized in the current system. + """ + return parallel_mode in self._groups + + def destroy(self): + """Destroys the current distributed parallel environment. + """ + for mode, group in self._groups.items(): + if mode is not ParallelMode.GLOBAL: + dist.destroy_process_group(group) + # destroy global process group + dist.destroy_process_group() + self._groups.clear() + + def set_device(self, device_ordinal: int = None): + """Sets distributed processes to be bound to devices. + + Args: + device_ordinal (int, optional): the device id to be bound to + """ + global_rank = self.get_global_rank() + if device_ordinal is None: + devices_per_node = torch.cuda.device_count() + device_ordinal = global_rank % devices_per_node + + torch.cuda.set_device(device_ordinal) + if self._verbose: + self._logger.info(f'process rank {global_rank} is bound to device {device_ordinal}') + + def set_seed(self, seed: int): + """Sets seeds for all random libraries. + + Args: + seed (int): seed for random states + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + global_rank = self.get_global_rank() + + if torch.cuda.is_available(): + # create random seed for different parallel modes + # data parallel seed are kept the same + parallel_seed = seed + add_seed(ParallelMode.DATA, parallel_seed) + + # model parallel seeds are different across ranks + pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0) + + # add seed for data parallel and tensor parallel only + if self.is_initialized(ParallelMode.TENSOR): + tp_rank = self.get_local_rank(ParallelMode.TENSOR) + # 100 is only to increase the diff in seeds between pipeline stages + tp_rank_with_offset = tp_rank + pipeline_offset * 1024 + tp_seed = seed + tp_rank_with_offset + add_seed(ParallelMode.TENSOR, tp_seed) + + set_mode(ParallelMode.DATA) + seeds = get_seeds() + seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()]) + + if self._verbose: + self._logger.info(f"initialized seed on rank {global_rank}, " + f"numpy: {seed}, python random: {seed}, {seed_str}," + f"the default parallel seed is {ParallelMode.DATA}.") + else: + if self._verbose: + self._logger.info( + f"initialized seed on rank {global_rank}, " + f"numpy: {seed}, python random: {seed}, pytorch: {seed}", + ranks=[0]) + self._logger.info( + 'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states', + ranks=[0]) + + def set_virtual_pipeline_parallel_size(self, size): + self.virtual_pipeline_parallel_size = size + + def set_virtual_pipeline_parallel_rank(self, rank): + self.virtual_pipeline_parallel_rank = rank + + +global_context = ParallelContext() diff --git a/colossalai/context/parallel_mode.py b/colossalai/context/parallel_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf6fa53dc1e5c31fbaf1c9140e0915419af704c --- /dev/null +++ b/colossalai/context/parallel_mode.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from enum import Enum + + +# parallel modes +class ParallelMode(Enum): + """This is an enumeration class containing all possible parallel modes. + """ + + GLOBAL = 'global' + + # common parallel + DATA = 'data' + + # model parallel - containing tensor and pipeline parallel groups + # this is added to facilitate amp and grad clipping in hybrid parallel + MODEL = 'model' + + # pipeline parallel + PIPELINE = 'pipe' + + # containing all ranks in tensor parallel + TENSOR = 'tensor' + + # sequence parallel + SEQUENCE = 'sequence' + SEQUENCE_DP = 'sequence_dp' + + # 1D Parallel + PARALLEL_1D = '1d' + + # 2D parallel + PARALLEL_2D_ROW = '2d_row' + PARALLEL_2D_COL = '2d_col' + + # 3D parallel + PARALLEL_3D_INPUT = '3d_input' + PARALLEL_3D_WEIGHT = '3d_weight' + PARALLEL_3D_OUTPUT = '3d_output' + PARALLEL_3D_INPUT_X_WEIGHT = "3d_input_x_weight" + PARALLEL_3D_OUTPUT_X_WEIGHT = "3d_output_x_weight" + + # 2.5D parallel + PARALLEL_2P5D_ROW = '2p5d_row' + PARALLEL_2P5D_COL = '2p5d_col' + PARALLEL_2P5D_DEP = '2p5d_dep' + PARALLEL_2P5D_XZ = '2p5d_xz' diff --git a/colossalai/context/process_group_initializer/__init__.py b/colossalai/context/process_group_initializer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d3937a9474376f0ecb7af612121bb4c3e5f5a497 --- /dev/null +++ b/colossalai/context/process_group_initializer/__init__.py @@ -0,0 +1,15 @@ +from .initializer_1d import Initializer_1D +from .initializer_2d import Initializer_2D +from .initializer_2p5d import Initializer_2p5D +from .initializer_3d import Initializer_3D +from .initializer_data import Initializer_Data +from .initializer_pipeline import Initializer_Pipeline +from .initializer_sequence import Initializer_Sequence +from .initializer_tensor import Initializer_Tensor +from .initializer_model import Initializer_Model +from .process_group_initializer import ProcessGroupInitializer + +__all__ = [ + 'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline', 'Initializer_Data', 'Initializer_2p5D', + 'Initializer_2D', 'Initializer_3D', 'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model' +] diff --git a/colossalai/context/process_group_initializer/initializer_1d.py b/colossalai/context/process_group_initializer/initializer_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..4c05028041cef2a9ad453b05d17d35b09ec2617d --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_1d.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.distributed as dist +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.registry import DIST_GROUP_INITIALIZER + +from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_1D(ProcessGroupInitializer): + """A ProcessGroupInitializer for 1d tensor parallelism. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_group = self.world_size // self.tensor_parallel_size + + def init_dist_group(self): + """Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 1D tensor parallelism's information in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_1D + env.parallel_input_1d = False + + for i in range(self.num_group): + ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode diff --git a/colossalai/context/process_group_initializer/initializer_2d.py b/colossalai/context/process_group_initializer/initializer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..fe0ba553d6f35e0e6c65f2241c843e2777d38565 --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_2d.py @@ -0,0 +1,154 @@ +import math + +import torch.distributed as dist + +from colossalai.registry import DIST_GROUP_INITIALIZER +from .process_group_initializer import ProcessGroupInitializer +from ..parallel_mode import ParallelMode +from colossalai.global_variables import tensor_parallel_env as env + + +def _check_summa_env_var(summa_dim): + # check environment variable for SUMMA + env_summa_dim = env.summa_dim + + if env_summa_dim: + assert int(env_summa_dim) == summa_dim, \ + 'SUMMA_DIM has been set in the current environment and ' \ + 'does not match with the value passed to this initialized' + else: + env.summa_dim = summa_dim + + +class Initializer_2D_Row(ProcessGroupInitializer): + """2d tensor parallel initialization among rows. + + Args: + num_group (int): The number of all tensor groups. + summa_dim (int): The dimension of SUMMA. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group, summa_dim, *args, **kwargs): + super(Initializer_2D_Row, self).__init__(*args, **kwargs) + self.num_group = num_group + self.summa_dim = summa_dim + + def init_dist_group(self): + """Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu. + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 2D tensor row parallelism's information in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_2D_ROW + + for i in range(self.num_group): + for j in range(self.summa_dim): + ranks = [i * self.tensor_parallel_size + j * self.summa_dim + k for k in range(self.summa_dim)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_2D_Col(ProcessGroupInitializer): + """2d tensor parallel initialization among cols. + + Args: + num_group (int): The number of all tensor groups. + summa_dim (int): The dimension of SUMMA. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group, summa_dim, *args, **kwargs): + super(Initializer_2D_Col, self).__init__(*args, **kwargs) + self.num_group = num_group + self.summa_dim = summa_dim + + def init_dist_group(self): + """Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 2D tensor col parallelism's information in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_2D_COL + + for i in range(self.num_group): + for j in range(self.summa_dim): + ranks = [i * self.tensor_parallel_size + j + k * self.summa_dim for k in range(self.summa_dim)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_2D(ProcessGroupInitializer): + """ + Serve as the single entry point to 2D parallel initialization. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_group = self.world_size // self.tensor_parallel_size + self.summa_dim = int(math.sqrt(self.tensor_parallel_size)) + + assert self.tensor_parallel_size == self.summa_dim ** 2, \ + "2D summa dim should equal to tensor parallel size ^ 0.5" + _check_summa_env_var(self.summa_dim) + + self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs) + self.row_initializer = Initializer_2D_Row(self.num_group, self.summa_dim, *args, **kwargs) + + def init_dist_group(self): + """Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]: + 2D tensor parallelism's information in a list of tuples. + """ + parallel_setting = [self.row_initializer.init_dist_group(), self.col_initializer.init_dist_group()] + return parallel_setting diff --git a/colossalai/context/process_group_initializer/initializer_2p5d.py b/colossalai/context/process_group_initializer/initializer_2p5d.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6fdc5d715c30169f04cef54abd946c4c46b904 --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_2p5d.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math + +import torch.distributed as dist +from colossalai.context import Config +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.registry import DIST_GROUP_INITIALIZER + +from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer + + +def _check_tesseract_env_var(tesseract_dim: int, tesseract_dep: int): + # check global variable for TESSERACT + env_tesseract_dim = env.tesseract_dim + env_tesseract_dep = env.tesseract_dep + + if env_tesseract_dim and env_tesseract_dep: + assert int(env_tesseract_dim) == tesseract_dim, \ + 'TESSERACT_DIM has been set in the current environment and ' \ + 'does not match with the value passed to this initialized' + assert int(env_tesseract_dep) == tesseract_dep, \ + 'TESSERACT_DEP has been set in the current environment and ' \ + 'does not match with the value passed to this initialized' + else: + env.tesseract_dim = tesseract_dim + env.tesseract_dep = tesseract_dep + + +# i row j col k dep +class Initializer_2p5D_ROW(ProcessGroupInitializer): + """2.5d tensor parallel initialization among rows. + + Args: + tesseract_dim (int): The dimension of tesseract. + tesseract_dep (int): The dimension of depth. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): + super(Initializer_2p5D_ROW, self).__init__(*args) + self.num_group = self.world_size // self.tensor_parallel_size + self.tesseract_dep = tesseract_dep + self.tesseract_dim = tesseract_dim + assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ + "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" + + def init_dist_group(self): + """Initialize 2.5D tensor row parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 2.5D tensor row parallelism's information in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_2P5D_ROW + + for h in range(self.num_group): + for j in range(self.tesseract_dim): + for k in range(self.tesseract_dep): + ranks = [ + h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k) + for i in range(self.tesseract_dim) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_2p5D_Col(ProcessGroupInitializer): + """2.5d tensor parallel initialization among cols. + + Args: + tesseract_dim (int): The dimension of tesseract. + tesseract_dep (int): The dimension of depth. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): + super(Initializer_2p5D_Col, self).__init__(*args) + self.num_group = self.world_size // self.tensor_parallel_size + self.tesseract_dep = tesseract_dep + self.tesseract_dim = tesseract_dim + + def init_dist_group(self): + """Initialize 2.5D tensor col parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 2.5D tensor col parallelism's information in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_2P5D_COL + + for h in range(self.num_group): + for i in range(self.tesseract_dim): + for k in range(self.tesseract_dep): + ranks = [ + h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k) + for j in range(self.tesseract_dim) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_2p5D_Dep(ProcessGroupInitializer): + """2.5D tensor parallel initialization among depths. + + Args: + tesseract_dim (int): The dimension of tesseract. + tesseract_dep (int): The dimension of depth. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): + super(Initializer_2p5D_Dep, self).__init__(*args) + self.num_group = self.world_size // self.tensor_parallel_size + self.tesseract_dep = tesseract_dep + self.tesseract_dim = tesseract_dim + + def init_dist_group(self): + """Initialize 2.5D tensor depth parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 2.5D tensor depth parallelism's information in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_2P5D_DEP + + for h in range(self.num_group): + for i in range(self.tesseract_dim): + for j in range(self.tesseract_dim): + ranks = [ + h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k) + for k in range(self.tesseract_dep) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +# i row j col k dep +class Initializer_2p5D_XZ(ProcessGroupInitializer): + """2.5d tensor parallel initialization among cols times dep. + + Args: + tesseract_dim (int): The dimension of tesseract. + tesseract_dep (int): The dimension of depth. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): + super(Initializer_2p5D_XZ, self).__init__(*args) + self.num_group = self.world_size // self.tensor_parallel_size + self.tesseract_dep = tesseract_dep + self.tesseract_dim = tesseract_dim + + def init_dist_group(self): + """Initialize 2.5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 2.5D tensor colXdepth parallelism's information in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_2P5D_XZ + + for h in range(self.num_group): + for i in range(self.tesseract_dim): + ranks = [ + h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k) + for k in range(self.tesseract_dep) + for j in range(self.tesseract_dim) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_2p5D(ProcessGroupInitializer): + """ + Serve as the single entry point to Tesseract parallel initialization. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + depth (int): The depth of 2.5d parallel. + """ + + def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int, + tensor_parallel_size: int, depth: int): + args = (rank, world_size, config, data_parallel_size, pipeline_parallel_size, tensor_parallel_size) + super().__init__(*args) + self.num_group = self.world_size // self.tensor_parallel_size + self.tesseract_dim = int(math.sqrt(self.tensor_parallel_size / depth)) + self.tesseract_dep = depth + + assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ + "2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5" + _check_tesseract_env_var(self.tesseract_dim, self.tesseract_dep) + + self.col_initializer = Initializer_2p5D_Col(self.tesseract_dim, self.tesseract_dep, *args) + self.row_initializer = Initializer_2p5D_ROW(self.tesseract_dim, self.tesseract_dep, *args) + self.dep_initializer = Initializer_2p5D_Dep(self.tesseract_dim, self.tesseract_dep, *args) + self.xz_initializer = Initializer_2p5D_XZ(self.tesseract_dim, self.tesseract_dep, *args) + + def init_dist_group(self): + """Initialize 2.5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]: + Whole 2.5D tensor parallelism's information in a list of tuples. + """ + parallel_setting = [ + self.col_initializer.init_dist_group(), + self.row_initializer.init_dist_group(), + self.dep_initializer.init_dist_group(), + self.xz_initializer.init_dist_group() + ] + return parallel_setting diff --git a/colossalai/context/process_group_initializer/initializer_3d.py b/colossalai/context/process_group_initializer/initializer_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..b752b8f456540672ad02e2388f228054228aca95 --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_3d.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math + +import torch.distributed as dist +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.registry import DIST_GROUP_INITIALIZER + +from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer + + +def _check_depth_env_var(depth): + # check global variable + env_depth = env.depth_3d + + if env_depth: + assert int(env_depth) == depth, \ + 'DEPTH_3D has been set in the current environment and ' \ + 'does not match with the value passed to this initialized' + else: + env.depth_3d = depth + + +class Initializer_3D_Input(ProcessGroupInitializer): + """3D tensor parallel initialization among input. + + Args: + num_group (int): The number of all tensor groups. + depth (int): Depth of 3D parallelism. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group: int, depth: int, *args): + super().__init__(*args) + self.num_group = num_group + self.depth = depth + + def init_dist_group(self): + """Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 3D tensor parallelism's information among input in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_3D_INPUT + env.input_group_3d = mode + + for h in range(self.num_group): + for i in range(self.depth): + for k in range(self.depth): + ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_3D_Weight(ProcessGroupInitializer): + """3D tensor parallel initialization among weight. + + Args: + num_group (int): The number of all tensor groups. + depth (int): Depth of 3D parallelism. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group: int, depth: int, *args): + super().__init__(*args) + self.num_group = num_group + self.depth = depth + + def init_dist_group(self): + """Initialize 3D tensor parallel groups among weight, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 3D tensor parallelism's information among weight in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_3D_WEIGHT + env.weight_group_3d = mode + + for h in range(self.num_group): + for k in range(self.depth): + for j in range(self.depth): + ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_3D_Output(ProcessGroupInitializer): + """3D tensor parallel initialization among output. + + Args: + num_group (int): The number of all tensor groups. + depth (int): Depth of 3D parallelism. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group: int, depth: int, *args): + super().__init__(*args) + self.num_group = num_group + self.depth = depth + + def init_dist_group(self): + """Initialize 3D tensor parallel groups among output, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 3D tensor parallelism's information among output in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_3D_OUTPUT + env.output_group_3d = mode + + for h in range(self.num_group): + for i in range(self.depth): + for j in range(self.depth): + ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_3D_InputxWeight(ProcessGroupInitializer): + """3D tensor parallel initialization among input. + + Args: + num_group (int): The number of all tensor groups. + depth (int): Depth of 3D parallelism. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group: int, depth: int, *args): + super().__init__(*args) + self.num_group = num_group + self.depth = depth + + def init_dist_group(self): + """Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 3D tensor parallelism's information among input in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_3D_INPUT_X_WEIGHT + env.input_x_weight_group_3d = mode + + for h in range(self.num_group): + for k in range(self.depth): + ranks = [ + h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth) + for i in range(self.depth) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_3D_OutputxWeight(ProcessGroupInitializer): + """3D tensor parallel initialization among input. + + Args: + num_group (int): The number of all tensor groups. + depth (int): Depth of 3D parallelism. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group: int, depth: int, *args): + super().__init__(*args) + self.num_group = num_group + self.depth = depth + + def init_dist_group(self): + """Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 3D tensor parallelism's information among input in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_3D_OUTPUT_X_WEIGHT + env.output_x_weight_group_3d = mode + + for h in range(self.num_group): + for j in range(self.depth): + ranks = [ + h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth) + for i in range(self.depth) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_3D(ProcessGroupInitializer): + """Serve as the single entry point to 3D parallel initialization. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, *args): + super().__init__(*args) + self.num_group = self.world_size // self.tensor_parallel_size + self.depth = round(math.pow(self.tensor_parallel_size, 1 / 3)) + assert self.tensor_parallel_size == self.depth ** 3, \ + f'3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})' + _check_depth_env_var(self.depth) + + self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args) + self.weight_initializer = Initializer_3D_Weight(self.num_group, self.depth, *args) + self.output_initializer = Initializer_3D_Output(self.num_group, self.depth, *args) + self.input_x_weight_initializer = Initializer_3D_InputxWeight(self.num_group, self.depth, *args) + self.output_x_weight_initializer = Initializer_3D_OutputxWeight(self.num_group, self.depth, *args) + + def init_dist_group(self): + """Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]: + Whole 3D tensor parallelism's information in a list of tuples. + """ + parallel_setting = [ + self.input_initializer.init_dist_group(), + self.weight_initializer.init_dist_group(), + self.output_initializer.init_dist_group(), + self.input_x_weight_initializer.init_dist_group(), + self.output_x_weight_initializer.init_dist_group() + ] + return parallel_setting diff --git a/colossalai/context/process_group_initializer/initializer_data.py b/colossalai/context/process_group_initializer/initializer_data.py new file mode 100644 index 0000000000000000000000000000000000000000..0b8b0d91fcb9b7045ab2f7d6cc6948bee0397469 --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_data.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from torch import distributed as dist + +from colossalai.registry import DIST_GROUP_INITIALIZER +from .process_group_initializer import ProcessGroupInitializer +from ..parallel_mode import ParallelMode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Data(ProcessGroupInitializer): + """A ProcessGroupInitializer for data parallelism. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_data_parallel_group = self.world_size // self.data_parallel_size + + def init_dist_group(self): + """Initialize data parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + A Data parallelism's information tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.DATA + + for i in range(self.num_data_parallel_group): + ranks = [i + j * self.num_data_parallel_group for j in range(self.data_parallel_size)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode diff --git a/colossalai/context/process_group_initializer/initializer_model.py b/colossalai/context/process_group_initializer/initializer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..99b9cc0d4edce35915c52c01fa5875545256ba97 --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_model.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.distributed as dist +from colossalai.registry import DIST_GROUP_INITIALIZER +from .process_group_initializer import ProcessGroupInitializer +from ..parallel_mode import ParallelMode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Model(ProcessGroupInitializer): + """A ProcessGroupInitializer for model parallelism (model parallel group contains pipeline and tensor parallel + groups). + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_parallel_size = self.tensor_parallel_size * self.pipeline_parallel_size + self.num_group = self.world_size // self.model_parallel_size + + def init_dist_group(self): + """Initialize model parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + A Model parallelism's information tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.MODEL + + for i in range(self.num_group): + ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode diff --git a/colossalai/context/process_group_initializer/initializer_pipeline.py b/colossalai/context/process_group_initializer/initializer_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..edd1a3706c6863b187f49ef7c2fcf0d53afcdddf --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_pipeline.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from torch import distributed as dist + +from colossalai.registry import DIST_GROUP_INITIALIZER +from .process_group_initializer import ProcessGroupInitializer +from ..parallel_mode import ParallelMode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Pipeline(ProcessGroupInitializer): + """A ProcessGroupInitializer for pipeline parallelism. + + Args: + rank (int): The rank of current process + world_size (int): Size of whole communication world + config (Config): Running configuration + data_parallel_size (int): Size of data parallel + pipeline_parallel_size (int): Size of pipeline parallel + tensor_parallel_size (int): Size of tensor parallel + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.data_group_size = self.world_size // self.data_parallel_size + self.pipeline_stage_size = self.data_group_size // self.pipeline_parallel_size + + def init_dist_group(self): + """Initialize pipeline parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]: + A Pipeline parallelism's information in list of tuples. + """ + dist_settings = list() + for i in range(self.data_parallel_size): + for j in range(self.pipeline_stage_size): + pipe_ranks = list( + range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size)) + pipe_group_size = len(pipe_ranks) + pipe_group = dist.new_group(pipe_ranks) + group_cpu = dist.new_group(pipe_ranks, backend='gloo') if dist.get_backend() != 'gloo' else pipe_group + + if self.rank in pipe_ranks: + local_rank = pipe_ranks.index(self.rank) + group_world_size = pipe_group_size + process_group = pipe_group + cpu_group = group_cpu + ranks_in_group = pipe_ranks + dist_settings.append( + tuple((local_rank, group_world_size, process_group, cpu_group, ranks_in_group, + ParallelMode.PIPELINE))) + + return dist_settings diff --git a/colossalai/context/process_group_initializer/initializer_sequence.py b/colossalai/context/process_group_initializer/initializer_sequence.py new file mode 100644 index 0000000000000000000000000000000000000000..682fe4bb7633e4e2c158b0485baccb5a00691630 --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_sequence.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import torch.distributed as dist + +from colossalai.registry import DIST_GROUP_INITIALIZER +from .initializer_tensor import Initializer_Tensor +from .process_group_initializer import ProcessGroupInitializer +from ..parallel_mode import ParallelMode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Sequence_DP(ProcessGroupInitializer): + """A ProcessGroupInitializer for sequence parallelism all-reduce. + + In Sequence Parallelism, each GPU holds the full copy of model weights, + thus, gradient all-reduce occurs across all processes in the same pipeline stage + + Args: + rank (int): The rank of current process + world_size (int): Size of whole communication world + config (Config): Running configuration + data_parallel_size (int): Size of data parallel + pipeline_parallel_size (int): Size of pipeline parallel + tensor_parallel_size (int): Size of tensor parallel + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dp_size = self.world_size // self.pipeline_parallel_size + self.num_group = self.pipeline_parallel_size + + def init_dist_group(self): + """Initialize Sequence Parallel process groups used for gradient all-reduce. + + Returns: + Tuple: A tuple (local_rank, group_world_size, process_group, ranks_in_group, mode). + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.SEQUENCE_DP + + for i in range(self.num_group): + ranks = [i * self.dp_size + j for j in range(self.dp_size)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Sequence(ProcessGroupInitializer): + """A ProcessGroupInitializer for sequence parallelism. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # reuse tensor parallel initializer code + self._sequence_initializer = Initializer_Tensor(*args, **kwargs) + self._sequence_dp_initializer = Initializer_Sequence_DP(*args, **kwargs) + + def init_dist_group(self): + """Initialize Sequence parallel process groups and assign local_ranks and groups to each gpu. + + Sequence parallelism requires 2 process groups. The first is for model forward where several processes + exchange partial query, key and value embedding to compute self attention values. The second is for + all-reduce to synchronize the model parameters. + + Returns: + List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]: + A Sequence parallelism's information in list of tuples. + """ + + parallel_setting = [] + + local_rank, group_world_size, process_group, cpu_grop, ranks_in_group, mode = \ + self._sequence_initializer.init_dist_group() + # change mode to sequence + mode = ParallelMode.SEQUENCE + + parallel_setting.append((local_rank, group_world_size, process_group, cpu_grop, ranks_in_group, mode)) + parallel_setting.append(self._sequence_dp_initializer.init_dist_group()) + return parallel_setting diff --git a/colossalai/context/process_group_initializer/initializer_tensor.py b/colossalai/context/process_group_initializer/initializer_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..d2b5be9cfffbe9eb7234411c6526d4055c078f12 --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_tensor.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.distributed as dist + +from colossalai.registry import DIST_GROUP_INITIALIZER +from .process_group_initializer import ProcessGroupInitializer +from ..parallel_mode import ParallelMode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Tensor(ProcessGroupInitializer): + """A ProcessGroupInitializer for tensor parallelism. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_tensor_parallel_group = self.world_size // self.tensor_parallel_size + + def init_dist_group(self): + """Initialize tensor parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + A Tensor parallelism's information tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.TENSOR + + for i in range(self.num_tensor_parallel_group): + ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode diff --git a/colossalai/context/process_group_initializer/process_group_initializer.py b/colossalai/context/process_group_initializer/process_group_initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..98150ce8e428a3b9bf81185719685b38efc2bdfd --- /dev/null +++ b/colossalai/context/process_group_initializer/process_group_initializer.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod + +from colossalai.context import Config + + +class ProcessGroupInitializer(ABC): + """An object, knowing the parallelism configuration, that initializes parallel groups. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int, + tensor_parallel_size: int): + self.rank = rank + self.world_size = world_size + self.data_parallel_size = data_parallel_size + self.config = config + self.pipeline_parallel_size = pipeline_parallel_size + self.tensor_parallel_size = tensor_parallel_size + super().__init__() + + @abstractmethod + def init_dist_group(self): + pass diff --git a/colossalai/context/random/__init__.py b/colossalai/context/random/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..422c3676c09d6698c7dd81b2b1c5d3c3b3a0bc50 --- /dev/null +++ b/colossalai/context/random/__init__.py @@ -0,0 +1,7 @@ +from ._helper import (seed, set_mode, with_seed, add_seed, get_seeds, get_states, get_current_mode, set_seed_states, + sync_states, moe_set_seed, reset_seeds) + +__all__ = [ + 'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states', + 'sync_states', 'moe_set_seed', 'reset_seeds' +] diff --git a/colossalai/context/random/_helper.py b/colossalai/context/random/_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..973c4d9faa325820aa1dedc5e133551430778057 --- /dev/null +++ b/colossalai/context/random/_helper.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import functools +from contextlib import contextmanager + +import torch.cuda +from torch import Tensor + +from .seed_manager import SeedManager +from ..parallel_mode import ParallelMode + +_SEED_MANAGER = SeedManager() + + +def get_seeds(): + """Returns the seeds of the seed manager. + + Returns: + dict: The seeds of the seed manager. + """ + return _SEED_MANAGER.seeds + + +def get_states(copy=False): + """Returns the seed states of the seed manager. + + Returns: + dict: The seed states of the seed manager. + """ + states = _SEED_MANAGER.seed_states + + if copy: + new_states = dict() + + for parallel_mode, state in states.items(): + new_states[parallel_mode] = state.clone() + return new_states + else: + return _SEED_MANAGER.seed_states + + +def get_current_mode(): + """Returns the current mode of the seed manager. + + Returns: + :class:`torch.ByteTensor`: The current mode of the seed manager. + """ + return _SEED_MANAGER.current_mode + + +def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False): + """Adds a seed to the seed manager for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + seed (int): The seed to be added + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of + :class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + _SEED_MANAGER.add_seed(parallel_mode, seed, overwrite) + + +def set_mode(parallel_mode: ParallelMode): + """Sets the current mode of the seed manager. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + _SEED_MANAGER.set_mode(parallel_mode) + + +def set_seed_states(parallel_mode: ParallelMode, state: Tensor): + """Sets the state of the seed manager for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + state (:class:`torch.Tensor`): the state to be set. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager. + """ + _SEED_MANAGER.set_state(parallel_mode, state) + + +def sync_states(): + current_mode = get_current_mode() + current_states = torch.cuda.get_rng_state() + set_seed_states(current_mode, current_states) + + +@contextmanager +def seed(parallel_mode: ParallelMode): + """ A context for seed switch + + Examples: + + >>> with seed(ParallelMode.DATA): + >>> output = F.dropout(input) + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + try: + # set to new mode + current_mode = _SEED_MANAGER.current_mode + yield _SEED_MANAGER.set_mode(parallel_mode) + finally: + # recover + _SEED_MANAGER.set_mode(current_mode) + + +def with_seed(func, parallel_mode: ParallelMode): + """ + A function wrapper which executes the function with a specified seed. + + Examples: + + >>> # use with decorator + >>> @with_seed(ParallelMode.DATA) + >>> def forward(input): + >>> return F.dropout(input) + >>> out = forward(input) + >>> # OR use it inline + >>> def forward(input): + >>> return F.dropout(input) + >>> wrapper_forward = with_seed(forward, ParallelMode.DATA) + >>> out = wrapped_forward(input) + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # switch mode + current_mode = _SEED_MANAGER.current_mode + _SEED_MANAGER.set_mode(parallel_mode) + + # exec func + out = func(*args, **kwargs) + + # recover state + _SEED_MANAGER.set_mode(current_mode) + + return out + + return wrapper + + +def moe_set_seed(seed): + if torch.cuda.is_available(): + from colossalai.core import global_context as gpc + global_rank = gpc.get_global_rank() + diff_seed = seed + global_rank + add_seed(ParallelMode.TENSOR, diff_seed, True) + print(f"moe seed condition: {global_rank} with tensor seed {diff_seed}", flush=True) + + +def reset_seeds(): + _SEED_MANAGER.reset() diff --git a/colossalai/context/random/seed_manager.py b/colossalai/context/random/seed_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..3c84aaafc179a72e89b828ca794a985a3f340be5 --- /dev/null +++ b/colossalai/context/random/seed_manager.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +from torch import Tensor + +from colossalai.context.parallel_mode import ParallelMode + + +class SeedManager: + """This class is a manager of all random seeds involved in the system. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + def __init__(self): + self._current_mode = None + self._seeds = dict() + self._seed_states = dict() + + @property + def current_mode(self): + return self._current_mode + + @property + def seeds(self): + return self._seeds + + @property + def seed_states(self): + return self._seed_states + + def set_state(self, parallel_mode: ParallelMode, state: Tensor): + """Sets the state of the seed manager for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + state (:class:`torch.Tensor`): the state to be set. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager. + """ + assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager' + self._seed_states[parallel_mode] = state + + def set_mode(self, parallel_mode: ParallelMode): + """Sets the current mode of the seed manager. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + """ + if self.current_mode: + # save the current state for current mode + self._seed_states[self._current_mode] = torch.cuda.get_rng_state() + + # set the new state for new mode + self._current_mode = parallel_mode + torch.cuda.set_rng_state(self._seed_states[parallel_mode]) + + def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrtie: bool = False): + """Adds a seed to the seed manager for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + seed (int): The seed to be added. + overwrtie (bool, optional): Whether allows to overwrite the seed that has been set already + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.context.ParallelMode` + or the seed for `parallel_mode` has been added. + """ + assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided' + if overwrtie is False: + assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added' + elif parallel_mode in self._seed_states: + print(f"Warnning: {parallel_mode} seed has been overwritten.", flush=True) + + current_state = torch.cuda.get_rng_state() + torch.cuda.manual_seed(seed) + self._seed_states[parallel_mode] = torch.cuda.get_rng_state() + self._seeds[parallel_mode] = seed + torch.cuda.set_rng_state(current_state) + + def reset(self): + self._current_mode = None + self._seeds = dict() + self._seed_states = dict() diff --git a/colossalai/context/singleton_meta.py b/colossalai/context/singleton_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..8ca335119d52ad2a212b1e0c578202b2fc6bb60f --- /dev/null +++ b/colossalai/context/singleton_meta.py @@ -0,0 +1,21 @@ +class SingletonMeta(type): + """ + The Singleton class can be implemented in different ways in Python. Some + possible methods include: base class, decorator, metaclass. We will use the + metaclass because it is best suited for this purpose. + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + """ + Possible changes to the value of the `__init__` argument do not affect + the returned instance. + """ + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + else: + assert len(args) == 0 and len( + kwargs) == 0, f'{cls.__name__} is a singleton class and a instance has been created.' + return cls._instances[cls] diff --git a/colossalai/core.py b/colossalai/core.py new file mode 100644 index 0000000000000000000000000000000000000000..153247bbed9c65db0b2255247137fa9a64a693fa --- /dev/null +++ b/colossalai/core.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from colossalai.context.parallel_context import global_context + +__all__ = ['global_context'] \ No newline at end of file diff --git a/colossalai/device/__init__.py b/colossalai/device/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..879b60c06a59bad14e14a2757df0b8b88798df25 --- /dev/null +++ b/colossalai/device/__init__.py @@ -0,0 +1,4 @@ +from .calc_pipeline_strategy import alpa_dp +from .profile_alpha_beta import profile_alpha_beta + +__all__ = ['profile_alpha_beta', 'alpa_dp'] diff --git a/colossalai/device/calc_pipeline_strategy.py b/colossalai/device/calc_pipeline_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..4ab72dfe60f0c73f0e4f5186ed54205b68513bc0 --- /dev/null +++ b/colossalai/device/calc_pipeline_strategy.py @@ -0,0 +1,127 @@ +from math import pow + +import numpy as np + + +def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"): + submesh_choices = [] + i = 1 + p = -1 + while i <= num_devices_per_host: + i *= 2 + p += 1 + assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, " + f"while now num_devices_per_host = {num_devices_per_host}") + if mode == "alpa": + for i in range(p + 1): + submesh_choices.append((1, pow(2, i))) + for i in range(2, num_hosts + 1): + submesh_choices.append((i, num_devices_per_host)) + elif mode == "new": + for i in range(p // 2 + 1): + for j in range(i, p - i + 1): + submesh_choices.append((pow(2, i), pow(2, j))) + return submesh_choices + + +def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, + best_configs): + """Implementation of Alpa DP for pipeline strategy + Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf + + Arguments: + num_layers: K + num_devices: N*M + num_microbatches: B + submesh_choices: List[(n_i,m_i)] + compute_cost: t_intra + """ + # For f, layer ID start from 0 + # f[#pipeline stages, layer id that is currently being considered, number of devices used] + f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32) + f_stage_max = np.full((num_layers + 1, num_layers + 1, num_devices + 1), 0.0, dtype=np.float32) + f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3), -1, dtype=np.int32) + f[0, num_layers, 0] = 0 + for s in range(1, num_layers + 1): + for k in range(num_layers - 1, -1, -1): + for d in range(1, num_devices + 1): + for m, submesh in enumerate(submesh_choices): + n_submesh_devices = np.prod(np.array(submesh)) + if n_submesh_devices <= d: + # TODO: [luzgh]: Why alpa needs max_n_succ_stages? Delete. + # if s - 1 <= max_n_succ_stages[i, k - 1, m, n_config]: + # ... + for i in range(num_layers, k, -1): + stage_cost = compute_cost[k, i, m] + new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost + if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]): + f[s, k, d] = new_cost + f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices]) + f_argmin[s, k, d] = (i, m, best_configs[k, i, m]) + best_s = -1 + best_total_cost = np.inf + for s in range(1, num_layers + 1): + if f[s, 0, num_devices] < best_total_cost: + best_s = s + best_total_cost = f[s, 0, num_devices] + + if np.isinf(best_total_cost): + return np.inf, None + + total_cost = f[best_s, 0, num_devices] + (num_microbatches - 1) * f_stage_max[best_s, 0, num_devices] + current_s = best_s + current_layer = 0 + current_devices = num_devices + + res = [] + while current_s > 0 and current_layer < num_layers and current_devices > 0: + next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices]) + assert next_start_layer != -1 and current_devices != -1 + res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice)) + current_s -= 1 + current_layer = next_start_layer + current_devices -= np.prod(np.array(submesh_choices[submesh_choice])) + assert (current_s == 0 and current_layer == num_layers and current_devices == 0) + + return total_cost, res + + +def alpa_dp(num_layers, + num_devices, + num_microbatches, + submesh_choices, + num_autosharding_configs, + compute_cost, + gap=1e-6): + """Alpa auto stage dynamic programming. + Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py + + Arguments: + submesh_choices: List[(int,int)] + num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh) + compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs) + """ + assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices), + num_autosharding_configs), "Cost shape wrong." + all_possible_stage_costs = np.sort(np.unique(compute_cost)) + best_cost = np.inf + best_solution = None + last_max_stage_cost = 0.0 + # TODO: [luzgh]: Why alpa needs the num_autosharding_configs dimension in compute_cost? + # In dp_impl it seems the argmin n_config will be chosen. Just amin here. + best_configs = np.argmin(compute_cost, axis=3) + best_compute_cost = np.amin(compute_cost, axis=3) + assert len(all_possible_stage_costs), "no solution in auto stage construction." + for max_stage_cost in all_possible_stage_costs: + if max_stage_cost * num_microbatches >= best_cost: + break + if max_stage_cost - last_max_stage_cost < gap: + continue + cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, + max_stage_cost, best_configs) + if cost < best_cost: + best_cost = cost + best_solution = solution + last_max_stage_cost = max_stage_cost + + return best_cost, best_solution diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..7596a100bf9310ece3c2201f4ee45ab869aeaea3 --- /dev/null +++ b/colossalai/device/device_mesh.py @@ -0,0 +1,240 @@ +import operator +from functools import reduce + +import torch +import torch.distributed as dist + + +class DeviceMesh: + """A logical view of a physical mesh. The logical view is used in the + search process. + A physical mesh can have multiple logical views. (e.g., a 2x8 physical mesh + can be viewed as a 1x16 or a 4x4 logical mesh). Each mesh dimension has its + own latency and bandwidth. We use alpha-beta model to model the + communication cost. + + Arguments: + physical_mesh_id (torch.Tensor): physical view of the devices in global rank. + mesh_shape (torch.Size): shape of logical view. + mesh_alpha (List[float], optional): coefficients used for computing + communication cost (default: None) + mesh_beta (List[float], optional): coefficients used for computing + communication cost (default: None) + init_process_group (bool, optional): initialize logical process group + during initializing the DeviceMesh instance if the init_process_group set to True. + Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group. + (default: False) + need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True. + """ + + def __init__(self, + physical_mesh_id, + mesh_shape, + mesh_alpha=None, + mesh_beta=None, + init_process_group=False, + need_flatten=True): + self.physical_mesh_id = physical_mesh_id + self.mesh_shape = mesh_shape + self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape) + # map global rank into logical rank + self.convert_map = {} + self._global_rank_to_logical_rank_map(self._logical_mesh_id, []) + # coefficient for alpha-beta communication model + if mesh_alpha is None: + mesh_alpha = [1] * len(self.mesh_shape) + if mesh_beta is None: + mesh_beta = [1] * len(self.mesh_shape) + self.mesh_alpha = tuple(mesh_alpha) + self.mesh_beta = tuple(mesh_beta) + self.init_process_group = init_process_group + self.need_flatten = need_flatten + if self.init_process_group: + self.process_groups_dict = self.create_process_groups_for_logical_mesh() + if self.need_flatten and self._logical_mesh_id.dim() > 1: + self.flatten_device_mesh = self.flatten() + # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten()) + self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha, + self.mesh_beta) + + @property + def shape(self): + return self.mesh_shape + + @property + def num_devices(self): + return reduce(operator.mul, self.physical_mesh_id.shape, 1) + + @property + def logical_mesh_id(self): + return self._logical_mesh_id + + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k != 'process_groups_dict': + setattr(result, k, __import__("copy").deepcopy(v, memo)) + else: + setattr(result, k, v) + + return result + + def flatten(self): + """ + Flatten the logical mesh into an effective 1d logical mesh, + """ + flatten_mesh_shape_size = len(self.mesh_shape) + flatten_mesh_shape = [self.num_devices] + return DeviceMesh(self.physical_mesh_id, + tuple(flatten_mesh_shape), + mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), + mesh_beta=[min(self.mesh_beta)] * (flatten_mesh_shape_size - 1), + init_process_group=self.init_process_group, + need_flatten=False) + + def _global_rank_to_logical_rank_map(self, tensor, index_list): + ''' + This method is a helper function to build convert_map recursively. + ''' + for index, inner_tensor in enumerate(tensor): + if inner_tensor.numel() == 1: + self.convert_map[int(inner_tensor)] = index_list + [index] + else: + self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index]) + + def create_process_groups_for_logical_mesh(self): + ''' + This method is used to initialize the logical process groups which will be used in communications + among logical device mesh. + Note: if init_process_group set to False, you have to call this method manually. Otherwise, + the communication related function, such as ShapeConsistencyManager.apply will raise errors. + ''' + process_groups_dict = {} + check_duplicate_list = [] + global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist() + for global_rank in global_rank_flatten_list: + process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank) + for axis, process_group in process_groups.items(): + if axis not in process_groups_dict: + process_groups_dict[axis] = [] + if process_group not in check_duplicate_list: + check_duplicate_list.append(process_group) + process_group_handler = dist.new_group(process_group) + process_groups_dict[axis].append((process_group, process_group_handler)) + + return process_groups_dict + + def global_rank_to_logical_rank(self, rank): + return self.convert_map[rank] + + def global_rank_to_process_groups_with_logical_rank(self, rank): + ''' + Give a global rank and return all logical process groups of this rank. + for example: + physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + mesh_shape = (4, 4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7], + # [8, 9, 10,11], + # [12,13,14,15]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + print(device_mesh.global_rank_to_process_groups_with_logical_rank(0)) + output: + # key is axis name + # value is a list of logical ranks in same axis with rank 0 + {0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]} + ''' + process_groups = {} + for d in range(self.logical_mesh_id.dim()): + for replacer in range(self.logical_mesh_id.shape[d]): + if d not in process_groups: + process_groups[d] = [] + process_group_member = self.convert_map[rank].copy() + process_group_member[d] = replacer + process_groups[d].append(process_group_member) + return process_groups + + def global_rank_to_process_groups_with_global_rank(self, rank): + ''' + Give a global rank and return all process groups of this rank. + for example: + physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + mesh_shape = (4, 4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7], + # [8, 9, 10,11], + # [12,13,14,15]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + print(device_mesh.global_rank_to_process_groups_with_global_rank(0)) + output: + # key is axis name + # value is a list of global ranks in same axis with rank 0 + {0: [0, 4, 8, 12], 1: [0, 1, 2, 3]} + ''' + logical_process_groups = self.global_rank_to_process_groups_with_logical_rank(rank) + process_groups = {} + for dim, logical_ranks in logical_process_groups.items(): + process_groups[dim] = [] + for logical_rank in logical_ranks: + for g_rank, l_rank in self.convert_map.items(): + if l_rank == logical_rank: + process_groups[dim].append(g_rank) + return process_groups + + def all_gather_cost(self, num_bytes, mesh_dim): + num_devices = self.logical_mesh_id.shape[mesh_dim] + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + + 0.1) + + def all_reduce_cost(self, num_bytes, mesh_dim): + num_devices = self.logical_mesh_id.shape[mesh_dim] + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes + + 0.01) + + def reduce_scatter_cost(self, num_bytes, mesh_dim): + num_devices = self.logical_mesh_id.shape[mesh_dim] + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + + 0.001) + + def all_to_all_cost(self, num_bytes, mesh_dim): + num_devices = self.logical_mesh_id.shape[mesh_dim] + penalty_factor = num_devices / 2.0 + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * + (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) + + +class FlattenDeviceMesh(DeviceMesh): + + def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None): + super().__init__(physical_mesh_id, + mesh_shape, + mesh_alpha, + mesh_beta, + init_process_group=False, + need_flatten=False) + # Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars + self.mesh_alpha = max(self.mesh_alpha) + self.mesh_beta = min(self.mesh_beta) + # Different from original process_groups_dict, rank_list is not stored + self.process_number_dict = self.create_process_numbers_for_logical_mesh() + + def create_process_numbers_for_logical_mesh(self): + ''' + Build 1d DeviceMesh in column-major(0) and row-major(1) + for example: + mesh_shape = (2,4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} + ''' + num_devices = reduce(operator.mul, self.mesh_shape, 1) + process_numbers_dict = {} + process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist() + process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist() + return process_numbers_dict + + def mix_gather_cost(self, num_bytes): + num_devices = reduce(operator.mul, self.mesh_shape, 1) + return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1) diff --git a/colossalai/device/profile_alpha_beta.py b/colossalai/device/profile_alpha_beta.py new file mode 100644 index 0000000000000000000000000000000000000000..2d053ddbec92188b52975c418be07e9de456e599 --- /dev/null +++ b/colossalai/device/profile_alpha_beta.py @@ -0,0 +1,120 @@ +import fcntl +import math +import os +import time + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +MB = int((1 << 10) * 1e3) +GB = int((1 << 20) * 1e3) +Byte = 4 +FRAMEWORK = 0 +NON_SENSE = (0.1, 0.1) + + +def printflock(*msgs): + """ solves multi-process interleaved print problem """ + with open(__file__, "r") as fh: + fcntl.flock(fh, fcntl.LOCK_EX) + try: + print(*msgs) + finally: + fcntl.flock(fh, fcntl.LOCK_UN) + + +def profile(device1d, nbytes, ctype): + warmup = 5 + repeat = 25 + rank = dist.get_rank() + src_device_num = device1d[0] + wsize = len(device1d) + group = dist.new_group(device1d) + + torch.cuda.set_device(rank) + device = torch.device("cuda", rank) + buf = torch.randn(nbytes // 4).to(device) + + torch.cuda.synchronize() + # warmup + for _ in range(warmup): + if ctype == "a": + dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group) + elif ctype == "b": + dist.broadcast(buf, src=src_device_num, group=group) + torch.cuda.synchronize() + + dist.barrier() + begin = time.perf_counter() + for _ in range(repeat): + if ctype == "a": + dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group) + elif ctype == "b": + dist.broadcast(buf, src=src_device_num, group=group) + torch.cuda.synchronize() + end = time.perf_counter() + dist.barrier() + + if rank == src_device_num: + avg_time_s = (end - begin) / repeat - FRAMEWORK + alg_band = nbytes / avg_time_s + if ctype == "b": + bus_band = alg_band + elif ctype == "a": + bus_band = 2 * (wsize - 1) / wsize * alg_band + print( + f"GPU:{rank}, Bytes: {nbytes} B,Time: {round(avg_time_s * 1e6,2)} us, Bus bandwidth: {round(bus_band / GB,2)} GB/s" + ) + return (avg_time_s, alg_band) + else: + return NON_SENSE # Just a placeholder + + +def profile_latency(device1d, it=3, ctype="a"): + latency = [] + for i in range(it): + nbytes = int(Byte << i) + (t, _) = profile(device1d, nbytes, ctype) + latency.append(t) + return min(latency) + + +def profile_bandwidth(device1d, maxbytes, ctype="a"): + (_, bandwidth) = profile(device1d, maxbytes, ctype) + return bandwidth + + +def profile_ab(rank, *args): + wsize = int(torch.cuda.device_count()) + device1d = args[0] + return_dict = args[1] + ctype = args[2] + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '29020' + dist.init_process_group(backend=dist.Backend.NCCL, init_method='env://', world_size=wsize, rank=rank) + + device = torch.device("cuda", rank) + max_nbytes = torch.tensor(torch.cuda.mem_get_info(device)[0]).to(device) + max_nbytes = min(int(4 * GB), int(GB << int(math.log2(max_nbytes.item() / GB)))) + + if rank == device1d[0]: + print(f"max_nbytes: {max_nbytes} B") + + alpha = profile_latency(device1d, it=5, ctype=ctype) + beta = 1 / profile_bandwidth(device1d, maxbytes=max_nbytes, ctype=ctype) + + if rank == device1d[0]: + print(f"alpha(us): {round(alpha * 1e6,2)}, beta(us/GB): {round(beta * 1e6 * GB,2)}") + return_dict[rank] = (alpha, beta) + + +def profile_alpha_beta(device1d): + assert torch.cuda.is_available() + assert len(device1d) > 0 and len(device1d) <= int(torch.cuda.device_count()) + + manager = mp.Manager() + return_dict = manager.dict() + ctype = "a" + mp.spawn(profile_ab, args=[device1d, return_dict, ctype], nprocs=int(torch.cuda.device_count())) + return return_dict[device1d[0]] diff --git a/colossalai/engine/__init__.py b/colossalai/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..158796befb312755ed92f77f7828557f55800e4c --- /dev/null +++ b/colossalai/engine/__init__.py @@ -0,0 +1,4 @@ +from ._base_engine import Engine +from .gradient_handler import * + +__all__ = ['Engine'] diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..146a296692275821ab38b84746835cb313f12e58 --- /dev/null +++ b/colossalai/engine/_base_engine.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import List, Iterable +from torch.nn import Module +from torch.nn.modules.loss import _Loss + +from colossalai.logging import get_dist_logger +from torch import Tensor +from colossalai.gemini.ophooks import register_ophooks_recursively, BaseOpHook +from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule +from typing import Optional, Type +from colossalai.engine.gradient_handler import BaseGradientHandler +from colossalai.logging import get_dist_logger + + +class Engine: + """Basic engine class for training and evaluation. It runs a specific process method + :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset. + It controls a iteration in training. + + Args: + model (``torch.nn.Module``): The neural network model. + optimizer (``colossalai.nn.optimizer.ColossalaiOptimizer``): Optimizer for updating the parameters. + criterion (``torch.nn.modules.loss._Loss``, optional): Loss function for calculating loss. + gradient_handlers (List[``BaseGradientHandler``], optional): A list of gradient handler used in backward. + clip_grad_norm (float, optional): The norm of gradient clipping. + ophook_list (list): List of ophook. + verbose (bool): whether to display log info. + schedule (''BaseSchedule''): Runtime schedule. + + Examples: + >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training + >>> model = ... + >>> criterion = ... + >>> optimizer = ... + >>> train_dataloader = ... + >>> engine, _, _, _ = colossalai.initialize(model, optimizer, criterion) + >>> engine.train() + >>> for inputs, labels in train_dataloader + >>> # set gradients to zero + >>> engine.zero_grad() + >>> # run forward pass + >>> outputs = engine(inputs) + >>> # compute loss value and run backward pass + >>> loss = engine.criterion(outputs, labels) + >>> engine.backward(loss) + >>> # update parameters + >>> engine.step() + + The example of using Engine in training could be find in + `Training with engine and trainer `_. and + `Run resnet cifar10 with engine `_. + """ + + def __init__(self, + model: Module, + optimizer: "ColossalaiOptimizer", + criterion: Optional[_Loss] = None, + gradient_handlers: Optional[List[BaseGradientHandler]] = None, + clip_grad_norm: float = 0.0, + ophook_list: Optional[List[BaseOpHook]] = None, + verbose: bool = True, + schedule: Optional[BaseSchedule] = None): + self._model = model + self._optimizer = optimizer + self._criterion = criterion + self._clip_grad_norm = clip_grad_norm + self._verbose = verbose + self._logger = get_dist_logger() + + # state + self.training = True # default + + # build gradient handler + if gradient_handlers: + self._gradient_handlers = gradient_handlers + else: + self._gradient_handlers = [] + + if ophook_list is None: + self._ophook_list = [] + else: + self._ophook_list = ophook_list + + # build schedule + if schedule: + assert isinstance(schedule, BaseSchedule), \ + f'expected schedule to be of type BaseSchedule, but got {type(schedule)}' + self._schedule = schedule + else: + self._schedule = NonPipelineSchedule() + if self.uses_pipeline: + self._schedule.pre_processing(self) + + #register hook if any + if len(self._ophook_list) > 0: + register_ophooks_recursively(self._model, self._ophook_list) + + @property + def ophooks(self): + """show current activated ophooks""" + return self._ophook_list + + @property + def model(self): + """Model attached to the engine""" + return self._model + + @property + def optimizer(self): + """Optimizer attached to the engine""" + return self._optimizer + + @property + def criterion(self): + """Criterion attached to the engine""" + return self._criterion + + @property + def schedule(self): + """Schedule attached to the engine""" + return self._schedule + + @property + def uses_pipeline(self): + """show the pipeline parallel used or not""" + return isinstance(self._schedule, (PipelineSchedule, InterleavedPipelineSchedule)) + + def add_hook(self, ophook: Type[BaseOpHook]) -> None: + """add necessary hook""" + # whether this hook exist + for h in self._ophook_list: + if type(h) == type(ophook): + logger = get_dist_logger() + logger.warning(f"duplicate hooks, at least two instance of {type(ophook)}") + self._ophook_list.append(ophook) + register_ophooks_recursively(self._model, self._ophook_list) + + def remove_hook(self, ophook: Type[BaseOpHook]) -> None: + """remove hook""" + logger = get_dist_logger() + logger.warning(f"removing hooks is currently not supported") + + def zero_grad(self): + """Set the gradient of parameters to zero + """ + self.optimizer.zero_grad() + + def step(self): + """Execute parameter update + """ + self._all_reduce_gradients() + self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm) + return self.optimizer.step() + + def backward(self, loss: Tensor): + """Start backward propagation given the loss value computed by a loss function. + + Args: + loss (:class:`torch.Tensor`): Loss value computed by a loss function. + """ + ret = self.optimizer.backward(loss) + for ophook in self._ophook_list: + ophook.post_iter() + return ret + + def backward_by_grad(self, tensor, grad): + """Start backward propagation given the gradient of the output tensor. + + Args: + tensor (:class:`torch.Tensor`): Output tensor. + grad (:class:`torch.Tensor`): Gradient passed back to the output. + """ + ret = self.optimizer.backward_by_grad(tensor, grad) + for ophook in self._ophook_list: + ophook.post_iter() + return ret + + def __call__(self, *args, **kwargs): + """Run the forward step for the model. + + Returns: + Tuple[:class:`torch.Tensor`] or :class:`torch.Tensor`: Output of the model. + """ + return self.model(*args, **kwargs) + + def _all_reduce_gradients(self): + """Handles all-reduce operations of gradients across different parallel groups. + """ + for handler in self._gradient_handlers: + handler.handle_gradient() + + def execute_schedule(self, data_iter: Iterable, **kwargs): + """Run the forward, loss computation, and backward for the model. + Returns a tuple of (output, label, loss). + + Returns: + Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss). + """ + output, label, loss = self._schedule.forward_backward_step(self, data_iter, **kwargs) + return output, label, loss + + def train(self): + """Sets the model to training mode. + """ + self.training = True + self._model.train() + + def eval(self): + """Sets the model to evaluation mode. + """ + self.training = False + self._model.eval() diff --git a/colossalai/engine/gradient_accumulation/__init__.py b/colossalai/engine/gradient_accumulation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4585b9a2529c905b3595710e332d40fb32c449da --- /dev/null +++ b/colossalai/engine/gradient_accumulation/__init__.py @@ -0,0 +1,50 @@ +import torch.nn as nn +from typing import List +from colossalai.engine import BaseGradientHandler +from typing import Iterable +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from ._gradient_accumulation import GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler + +__all__ = [ + 'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep', + 'GradAccumGradientHandler' +] + + +def accumulate_gradient(model: nn.Module, + optimizer: Optimizer, + dataloader: Iterable, + accumulate_size: int, + gradient_handlers: List[BaseGradientHandler] = None, + lr_scheduler: _LRScheduler = None): + r"""Turning model, optimizer, dataloader into corresponding object for gradient accumulation. + + Args: + model (:class:`torch.nn.Module`): your model object for gradient accumulation. + optimizer (:class:`torch.optim.Optimizer`): your optimizer object for gradient accumulation. + dataloader (:class:`torch.utils.data.DataLoader` or iterable objects): + your dataloader object, would be called like iter(dataloader) + accumulate_size (int): the number of steps to accumulate gradients + gradient_handlers (List[:class:`colossalai.engine.BaseGradientHandler`]): + list of gradient handler objects. Default is None. + lr_scheduler (`torch.optim.lr_scheduler` or `colossalai.nn.lr_scheduler`): + your ``lr_scheduler`` object for gradient accumulation. Defaults to None. + + More details about `gradient_handlers` could be found in + `Gradient_handler `_. + + More details about `lr_scheduler` could be found + `lr_scheduler `_. and + `how to adjust learning rate `_. + """ + optimizer = GradAccumOptimizer(optimizer, accumulate_size=accumulate_size, model=model) + dataloader = GradAccumDataloader(dataloader, accumulate_size=accumulate_size) + + if gradient_handlers is not None: + gradient_handlers = [GradAccumGradientHandler(handler, accumulate_size) for handler in gradient_handlers] + + if lr_scheduler is not None: + lr_scheduler = GradAccumLrSchedulerByStep(lr_scheduler, accumulate_size=accumulate_size) + + return optimizer, dataloader, gradient_handlers, lr_scheduler diff --git a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py new file mode 100644 index 0000000000000000000000000000000000000000..89c28c3be87abacfb04b99e3661da0bf15b224c7 --- /dev/null +++ b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Union +import torch.nn as nn +from torch import Tensor +from typing import Iterable, Any, Tuple +from colossalai.nn.optimizer import ColossalaiOptimizer +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader +from colossalai.utils import conditional_context +from colossalai.engine import BaseGradientHandler + + +class GradAccumOptimizer(ColossalaiOptimizer): + """A wrapper for the optimizer to enable gradient accumulation by skipping the steps + before accumulation size is reached. + + Args: + optim (:class:`torch.optim.Optimizer`): Your optimizer object for gradient accumulation. + accumulate_size (int): The number of steps to accumulate gradients. + model (:class:`torch.nn.Module`): + Your model object to check if it is DistributedDataParallel for special handling of no_sync() context. + """ + + def __init__(self, optim: Optimizer, accumulate_size: int, model: nn.Module = None): + super().__init__(optim) + self.accumulate_size = accumulate_size + self.accumulate_step = 0 + + # handle pytorch ddp auto all reduce + self.model = model + self.is_torch_ddp = isinstance(self.model, DistributedDataParallel) + + def zero_grad(self, *args, **kwargs) -> None: + """ + Set all gradients to zero. + + Args: + *args: positional arguments for the optimizer wrapped + **kwargs: keyword arguments for the optimizer wrapped + """ + + if self.accumulate_step == 0: + self.optim.zero_grad(*args, **kwargs) + + def step(self, *args, **kwargs) -> None: + """ + Update the model parameters. + + Args: + *args: positional arguments for the optimizer wrapped + **kwargs: keyword arguments for the optimizer wrapped + """ + + if self.accumulate_step < self.accumulate_size: + return None + else: + self.accumulate_step = 0 + return self.optim.step(*args, **kwargs) + + def clip_grad_norm(self, model: nn.Module, max_norm: float) -> None: + """ + Clip gradients by norm. + + Args: + model (:class:`torch.nn.Module`): a torch module instance + max_norm (float): the max norm for gradient clipping + """ + + if self.accumulate_step < self.accumulate_size: + pass + else: + self.optim.clip_grad_norm(model, max_norm) + + def backward(self, loss: Tensor) -> None: + """Execute backward pass. + + Args: + loss (:class:`torch.Tensor`): the loss value. + """ + + self.accumulate_step += 1 + + if self.is_torch_ddp: + no_sync = self.accumulate_step < self.accumulate_size + with conditional_context(self.model.no_sync(), enable=no_sync): + scaled_loss = loss / self.accumulate_size + self.optim.backward(scaled_loss) + else: + scaled_loss = loss / self.accumulate_size + self.optim.backward(scaled_loss) + + def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None: + """Execute backward pass given the gradients of the output. + + Args: + loss (:class:`torch.Tensor`): the loss value. + grad (:class:`torch.Tensor`): the output gradient. + """ + + self.accumulate_step += 1 + no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size + + if no_sync: + with self.model.no_sync(): + self.optim.backward_by_grad(tensor, grad) + else: + self.optim.backward_by_grad(tensor, grad) + + +class GradAccumDataloader: + """A wrapper for dataloader to enable gradient accumulation by dropping the last incomplete steps. + + Note: + The dataloader would drop the last incomplete steps for gradient accumulation. + For example, if a dataloader has 10 batches of data and accumulate size is 4. The model parameters will + be updated only twice at step 4 and step 8. The last two batches of data do not form a complete 4-step cycle. + Thus, they will be automatically skipped by this class. If the dataloader is not standard PyTorch dataloader, + (e.g. Dali dataloader), this class will automatically consume (load data for nothing) the remaining 2 batches. + + Args: + dataloader (``Iterable``): Your dataloader object for gradient accumulation. + accumulate_size (int): The number of steps to accumulate gradients. + """ + + def __init__(self, dataloader: Iterable, accumulate_size: int) -> None: + self.dataloader = dataloader + self.consume_remain_data = not isinstance(dataloader, DataLoader) + self.steps_per_epoch = len(dataloader) - len(dataloader) % accumulate_size + + def __getattr__(self, __name: str) -> Any: + return getattr(self.dataloader, __name) + + def __len__(self) -> int: + return self.steps_per_epoch + + def __iter__(self) -> Iterable: + self._cur_step = 0 + self._dataiter = iter(self.dataloader) + return self + + def __next__(self) -> Union[Tensor, Tuple[Tensor]]: + if self._cur_step < self.steps_per_epoch: + self._cur_step += 1 + data = next(self._dataiter) + + if self._cur_step == self.steps_per_epoch and self.consume_remain_data: + # this is to handle non standard pytorch dataloader + # such as dali dataloader + while True: + try: + _ = next(self._dataiter) + except StopIteration: + break + return data + else: + raise StopIteration + + +class GradAccumLrSchedulerByStep(_LRScheduler): + """A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps + before accumulation size is reached. + + Args: + lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`): + Your ``lr_scheduler`` object for gradient accumulation. + accumulate_size (int): The number of steps to accumulate gradients. + """ + + def __init__(self, lr_scheduler: _LRScheduler, accumulate_size: int) -> None: + self.lr_scheduler = lr_scheduler + self.accumulate_size = accumulate_size + self.accumulate_step = 0 + + @staticmethod + def compute_effective_steps_per_epoch(dataloader: Iterable, accumulate_size: int) -> int: + """ + Computes the number of effective training iterations. An effective iteration is defined + as the the aggregation of iterations. For examples, if accumulate_size = 4, + then 4 iterations are considered as one effective iteration. + + Args: + dataloader (``Iterable``): Your dataloader object for gradient accumulation. + accumulate_size (int): The number of steps to accumulate gradients. + + """ + return len(dataloader) // accumulate_size + + def __getattr__(self, __name: str) -> Any: + return getattr(self.lr_scheduler, __name) + + def step(self, *args, **kwargs) -> None: + """ + Update the learning rate. + + Args: + *args: positional arguments for the lr scheduler wrapped. + **kwargs: keyword arguments for the lr scheduler wrapped. + """ + self.accumulate_step += 1 + if self.accumulate_step < self.accumulate_size: + pass + else: + self.accumulate_step = 0 + self.lr_scheduler.step(*args, **kwargs) + + def get_lr(self) -> Tensor: + """ + Compute the next learning rate. + + Returns: + Tensor: the upcoming learning rate. + """ + + return self.lr_scheduler.get_lr() + + def get_last_lr(self) -> Tensor: + """ + Returns the current learning rate. + + Returns: + Tensor: the current learning rate. + """ + + return self.lr_scheduler.get_last_lr() + + def print_lr(self, *args, **kwargs) -> None: + """ + Print he learning rate. + + Args: + *args: positional arguments for the lr scheduler wrapped. + **kwargs: keyword arguments for the lr scheduler wrapped. + """ + self.lr_scheduler.print_lr(*args, **kwargs) + + def state_dict(self) -> dict: + """ + Returns the states of the lr scheduler as dictionary. + + Returns: + dict: the states of the lr scheduler. + """ + return self.lr_scheduler.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + """ + Load the states of the lr scheduler from a dictionary object. + + Returns: + dict: the states of the lr scheduler. + """ + self.lr_scheduler.load_state_dict(state_dict) + + +class GradAccumGradientHandler: + r"""A wrapper for the gradient handler to enable gradient accumulation by skipping the steps + before accumulation size is reached. + + Args: + grad_handler (:class:`colossalai.engine.BaseGradientHandler`): + Your ``gradient_handler`` object for gradient accumulation, would be called when achieving `accumulate_size`. + accumulate_size (int): The number of steps to accumulate gradients. + + More details about ``gradient_handlers`` could be found in + `Gradient_handler `_. + + """ + + def __init__(self, grad_handler: BaseGradientHandler, accumulate_size: int) -> None: + assert isinstance(grad_handler, BaseGradientHandler), \ + f'expected grad_handler to be type BaseGradientHandler, but got {type(grad_handler)}' + self.grad_handler = grad_handler + self.accumulate_size = accumulate_size + self.accumulate_step = 0 + + def handle_gradient(self) -> None: + """ + Handle gradients reduction only in the last gradient accumulation step. + """ + + self.accumulate_step += 1 + if self.accumulate_step < self.accumulate_size: + pass + else: + self.accumulate_step = 0 + self.grad_handler.handle_gradient() diff --git a/colossalai/engine/gradient_handler/__init__.py b/colossalai/engine/gradient_handler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6177da69ba5b19669575f5dc13de92e880b3573a --- /dev/null +++ b/colossalai/engine/gradient_handler/__init__.py @@ -0,0 +1,12 @@ +from ._base_gradient_handler import BaseGradientHandler +from ._data_parallel_gradient_handler import DataParallelGradientHandler +from ._zero_gradient_handler import ZeROGradientHandler +from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler +from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler +from ._moe_gradient_handler import MoeGradientHandler +from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler + +__all__ = [ + 'BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler', + 'MoeGradientHandler', 'SequenceParallelGradientHandler' +] diff --git a/colossalai/engine/gradient_handler/_base_gradient_handler.py b/colossalai/engine/gradient_handler/_base_gradient_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..c212359867d1d0e6ffbf551b787cf645e6a0d3da --- /dev/null +++ b/colossalai/engine/gradient_handler/_base_gradient_handler.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod + + +class BaseGradientHandler(ABC): + """A basic helper class to handle all-reduce operations of gradients across different parallel groups + before optimization. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def __init__(self, model, optimizer): + self._model = model + self._optimizer = optimizer + + @abstractmethod + def handle_gradient(self): + """A method to accumulate gradients across different parallel groups. Users should + write their own functions or just use the functions in pre-defined subclasses. + """ + pass diff --git a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..d113fc5164599327874f0755056aeb05b11b7637 --- /dev/null +++ b/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py @@ -0,0 +1,26 @@ +from colossalai.core import global_context as gpc +from colossalai.registry import GRADIENT_HANDLER +from ._base_gradient_handler import BaseGradientHandler +from ...context.parallel_mode import ParallelMode +from .utils import bucket_allreduce + + +@GRADIENT_HANDLER.register_module +class DataParallelGradientHandler(BaseGradientHandler): + """A helper class to handle all-reduce operations in a data parallel group. + A all-reduce collective communication will be operated in + :func:`handle_gradient` among a data parallel group. + For better performance, it bucketizes the gradients of all parameters that are + the same type to improve the efficiency of communication. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def handle_gradient(self): + """A method running a all-reduce operation in a data parallel group. + """ + # TODO: add memory buffer + if gpc.data_parallel_size > 1: + bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.DATA)) diff --git a/colossalai/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/engine/gradient_handler/_moe_gradient_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..b680f1fbfc42ba2f5abf9d1bf597d8f1beaebf73 --- /dev/null +++ b/colossalai/engine/gradient_handler/_moe_gradient_handler.py @@ -0,0 +1,45 @@ +from colossalai.core import global_context as gpc +from colossalai.registry import GRADIENT_HANDLER +from colossalai.utils.moe import get_moe_epsize_param_dict +from ._base_gradient_handler import BaseGradientHandler +from ...context.parallel_mode import ParallelMode +from .utils import bucket_allreduce +from colossalai.context.moe_context import MOE_CONTEXT + + +@GRADIENT_HANDLER.register_module +class MoeGradientHandler(BaseGradientHandler): + """A helper class to handle all-reduce operations in a data parallel group and + moe model parallel. A all-reduce collective communication will be operated in + :func:`handle_gradient` among a data parallel group. + For better performance, it bucketizes the gradients of all parameters that are + the same type to improve the efficiency of communication. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def __init__(self, model, optimizer=None): + super().__init__(model, optimizer) + + def handle_gradient(self): + """A method running an all-reduce operation in a data parallel group. + Then running an all-reduce operation for all parameters in experts + across moe model parallel group + """ + global_data = gpc.data_parallel_size + + if global_data > 1: + epsize_param_dict = get_moe_epsize_param_dict(self._model) + + # epsize is 1, indicating the params are replicated among processes in data parallelism + # use the ParallelMode.DATA to get data parallel group + # reduce gradients for all parameters in data parallelism + if 1 in epsize_param_dict: + bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA)) + + for ep_size in epsize_param_dict: + if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: + bucket_allreduce(param_list=epsize_param_dict[ep_size], + group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group) diff --git a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..83f5c00cf2af62aad9274c3b60f3df43365e0196 --- /dev/null +++ b/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python + +from collections import defaultdict + +import torch +import torch.distributed as dist +from colossalai.core import global_context as gpc +from colossalai.registry import GRADIENT_HANDLER +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from ._base_gradient_handler import BaseGradientHandler + + +@GRADIENT_HANDLER.register_module +class PipelineSharedModuleGradientHandler(BaseGradientHandler): + """A helper class to handle all-reduce operations in sub parallel groups. + A all-reduce collective communication will be operated in + :func:`handle_gradient` among all sub pipeline parallel groups. + For better performance, it bucketizes the gradients of all parameters that are + the same type to improve the efficiency of communication. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def handle_gradient(self): + """A method running a all-reduce operation in sub pipeline parallel groups. + """ + if gpc.pipeline_parallel_size > 1: + # bucketize and all-reduce + buckets = defaultdict(lambda: defaultdict(list)) + # Pack the buckets. + for param in self._model.parameters(): + group = getattr(param, 'pipeline_shared_module_pg', None) + if param.requires_grad and group is not None and ( + (hasattr(param, 'colo_attr') and not param.colo_attr.saved_grad.is_null()) + or param.grad is not None): + tp = param.data.type() + buckets[group][tp].append(param) + + # For each bucket, all-reduce and copy all-reduced grads. + for group, group_buckets in buckets.items(): + for tp, bucket in group_buckets.items(): + grads = [ + param.colo_attr.grad_payload if hasattr(param, 'colo_attr') else param.grad.data + for param in bucket + ] + coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device()) + dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) diff --git a/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..53a8ea935a42eb4bdd87d7194ab84c8422a232e1 --- /dev/null +++ b/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py @@ -0,0 +1,25 @@ +from colossalai.core import global_context as gpc +from colossalai.registry import GRADIENT_HANDLER +from ._base_gradient_handler import BaseGradientHandler +from ...context.parallel_mode import ParallelMode +from .utils import bucket_allreduce + + +@GRADIENT_HANDLER.register_module +class SequenceParallelGradientHandler(BaseGradientHandler): + """A helper class to handle all-reduce operations in a data parallel group. + A all-reduce collective communication will be operated in + :func:`handle_gradient` among a data parallel group. + For better performance, it bucketizes the gradients of all parameters that are + the same type to improve the efficiency of communication. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def handle_gradient(self): + """A method running a all-reduce operation in a data parallel group. + """ + if gpc.get_world_size(ParallelMode.SEQUENCE_DP) > 1: + bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.SEQUENCE_DP)) diff --git a/colossalai/engine/gradient_handler/_zero_gradient_handler.py b/colossalai/engine/gradient_handler/_zero_gradient_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..f85303e751846a56a0b80003b3ff023b41c55b34 --- /dev/null +++ b/colossalai/engine/gradient_handler/_zero_gradient_handler.py @@ -0,0 +1,20 @@ +from colossalai.registry import GRADIENT_HANDLER +from ._base_gradient_handler import BaseGradientHandler + + +@GRADIENT_HANDLER.register_module +class ZeROGradientHandler(BaseGradientHandler): + """A helper class to handle all-reduce operations in a data parallel group. + A all-reduce collective communication will be operated in + :func:`handle_gradient` among a data parallel group. + This class is specialized with ZeRO optimization. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def handle_gradient(self): + """A method running a all-reduce operation in a data parallel group. + """ + self._optimizer.sync_grad() diff --git a/colossalai/engine/gradient_handler/utils.py b/colossalai/engine/gradient_handler/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0801fe6cd50643acdd50deea7e4e7a71b403c80d --- /dev/null +++ b/colossalai/engine/gradient_handler/utils.py @@ -0,0 +1,29 @@ +import torch.distributed as dist +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from typing import Iterable + + +def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None): + # get communication world size + comm_size = dist.get_world_size(group) + # bucketize and all-reduce + buckets = {} + # Pack the buckets. + for param in param_list: + if param.requires_grad and param.grad is not None: + tp = param.data.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(param) + + # For each bucket, all-reduce and copy all-reduced grads. + for tp in buckets: + bucket = buckets[tp] + grads = [param.grad.data for param in bucket] + coalesced = _flatten_dense_tensors(grads) + coalesced /= comm_size + + dist.all_reduce(coalesced, group=group) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) diff --git a/colossalai/engine/schedule/__init__.py b/colossalai/engine/schedule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..54170286e99b891b898ffeaaa19ca2e07405791b --- /dev/null +++ b/colossalai/engine/schedule/__init__.py @@ -0,0 +1,5 @@ +from ._base_schedule import BaseSchedule +from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape +from ._non_pipeline_schedule import NonPipelineSchedule + +__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape'] diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..ba797bad977845ee74fce041d002a9d273afb90e --- /dev/null +++ b/colossalai/engine/schedule/_base_schedule.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod + +import torch + +from typing import Iterable, Callable +from colossalai.logging import get_dist_logger +from colossalai.utils import get_current_device + + +class BaseSchedule(ABC): + """A basic helper class to control the process of training or evaluation. + It mainly composes of forward_backward_step for gradient backward and + optimizer_step for parameters update. + For the convenience to enable FP16, we aggregate all codes that contain the + control of FP16 in class schedule. + + Args: + data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges them into data and label. + """ + + def __init__(self, data_process_func: Callable = None): + self.logger = get_dist_logger() + self.data_process_func = data_process_func + + @staticmethod + def _move_tensor(element): + if torch.is_tensor(element): + if not element.is_cuda: + return element.to(get_current_device()).detach() + return element + + def _move_to_device(self, data): + if isinstance(data, torch.Tensor): + data = data.to(get_current_device()) + elif isinstance(data, (list, tuple)): + data_to_return = [] + for element in data: + if isinstance(element, dict): + data_to_return.append({k: self._move_tensor(v) for k, v in element.items()}) + else: + data_to_return.append(self._move_tensor(element)) + data = data_to_return + elif isinstance(data, dict): + data = {k: self._move_tensor(v) for k, v in data.items()} + else: + raise TypeError( + f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") + return data + + def _get_batch_size(self, data): + if isinstance(data, torch.Tensor): + return data.size(0) + elif isinstance(data, (list, tuple)): + if isinstance(data[0], dict): + return data[0][list(data[0].keys())[0]].size(0) + return data[0].size(0) + elif isinstance(data, dict): + return data[list(data.keys())[0]].size(0) + + def load_batch(self, data_iter, to_gpu=True): + """Loads a batch from data iterator. It returns the data and labels which are + already in the same GPU as where the model's. + + Args: + data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader). + to_gpu (bool, optional): Whether the data should be moved to GPU + + Returns: + Tuple (:class:`Tensor`, :class:`torch.Tensor`): A tuple of (data, label). + """ + if data_iter is None: + raise RuntimeError('Dataloader is not defined.') + batch_data = next(data_iter) + + if to_gpu: + batch_data = self._move_to_device(batch_data) + self.batch_size = self._get_batch_size(batch_data) + return batch_data + + def pre_processing(self, engine): + """To perform actions before running the schedule. + """ + pass + + @abstractmethod + def forward_backward_step(self, + engine, + data_iter: Iterable, + forward_only: bool, + return_loss: bool = True, + return_output_label: bool = True): + """The process function over a batch of dataset for training or evaluation. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader). + forward_only (bool): If True, the process won't include backward. + return_loss (bool, optional): If False, the loss won't be returned. + return_output_label (bool, optional): If False, the output and label won't be returned. + """ + pass + + @staticmethod + def _call_engine(engine, inputs): + if isinstance(inputs, torch.Tensor): + return engine(inputs) + elif isinstance(inputs, (list, tuple)): + return engine(*inputs) + elif isinstance(inputs, dict): + return engine(**inputs) + else: + TypeError( + f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}") + + @staticmethod + def _call_engine_criterion(engine, outputs, labels): + assert isinstance(outputs, + (torch.Tensor, list, tuple, + dict)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}' + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + if isinstance(labels, torch.Tensor): + labels = (labels,) + + if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)): + return engine.criterion(*outputs, *labels) + elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict): + return engine.criterion(*outputs, **labels) + elif isinstance(outputs, dict) and isinstance(labels, dict): + return engine.criterion(**outputs, **labels) + elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)): + raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}") + else: + raise TypeError(f"Expected model outputs and labels to be of type torch.Tensor ' \ + '(which is auto-converted to tuple), list, tuple, or dict, ' \ + 'but got {type(outputs)} (model outputs) and {type(labels)} (labels)") diff --git a/colossalai/engine/schedule/_non_pipeline_schedule.py b/colossalai/engine/schedule/_non_pipeline_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..c62bfb7d7375f4d58d1ead3c244cb5aaadf804c1 --- /dev/null +++ b/colossalai/engine/schedule/_non_pipeline_schedule.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Iterable + +import torch +import inspect +from ._base_schedule import BaseSchedule +from colossalai.utils import conditional_context +from typing import Callable + + +class NonPipelineSchedule(BaseSchedule): + """A helper schedule class for no pipeline parallelism running environment. + During one process, it loads a batch of dataset and feeds it to the model. + After getting the output and calculating the loss, it will use :meth:`step` + to update the parameters if it is in training mode. + + Args: + data_process_func (Callable, optional): The preprocessing function which receives a batch of data + and returns a tuple in the form of (data, label). + and it will be executed in load_batch. + + Example: + # this shows an example of customized data_process_func + def data_process_func(dataloader_output): + item1, item2, item3 = dataloader_output + data = (item1, item2) + label = item3 + return data, label + """ + + def __init__(self, data_process_func: Callable = None): + # check that non-pipeline schedule data process func only takes in one parameter + # which is the batch data + + if data_process_func: + sig = inspect.signature(data_process_func) + assert len(sig.parameters) == 1, \ + 'The data_process_func only takes in one parameter for NonPipelineSchedule, ' \ + 'which is a tuple of tensors for the current batch, ' \ + 'i.e. data_process_func(dataloader_output).' + + super().__init__(data_process_func) + + def forward_backward_step(self, + engine, + data_iter: Iterable, + forward_only: bool = False, + return_loss: bool = True, + return_output_label: bool = True): + """The process function that loads a batch of dataset and feeds it to the model. + The returned labels and loss will None if :attr:`return_loss` is False. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). + forward_only (bool, optional): + If True, the model is run for the forward pass, else back propagation will be executed. + return_loss (bool, optional): Loss will be returned if True. + return_output_label (bool, optional): Output and label will be returned if True. + + Returns: + Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. + """ + assert forward_only or return_loss, \ + "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." + batch_data = self.load_batch(data_iter) + if self.data_process_func: + data, label = self.data_process_func(batch_data) + else: + # if not batch data process func is given, + # then we regard the batch data as a simple tuple of (data, label) + data, label = batch_data + + # forward + with conditional_context(torch.no_grad(), enable=forward_only): + output = self._call_engine(engine, data) + if return_loss: + loss = self._call_engine_criterion(engine, output, label) + + if not forward_only: + engine.backward(loss) + + if return_output_label: + if return_loss: + return output, label, loss + else: + return output, label, None + else: + if return_loss: + return None, None, loss + else: + return None, None, None diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..97571fa024baa7f1f294373fcaad8dcf75c3f00e --- /dev/null +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -0,0 +1,831 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import inspect +from typing import Callable, List, Tuple, Union + +import colossalai.communication as comm +import torch.cuda +from colossalai.amp.naive_amp import NaiveAMPModel +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.utils import switch_virtual_pipeline_parallel_rank +from colossalai.utils.cuda import get_current_device + +from ._base_schedule import BaseSchedule + + +def get_tensor_shape(): + if hasattr(gpc.config, 'TENSOR_SHAPE'): + return gpc.config.TENSOR_SHAPE + + if not gpc.is_initialized(ParallelMode.PIPELINE): + return None + + if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr( + gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'): + if gpc.is_initialized(ParallelMode.DATA): + dp_size = gpc.get_world_size(ParallelMode.DATA) + else: + dp_size = 1 + if gpc.is_initialized(ParallelMode.SEQUENCE): + seq_size = gpc.get_world_size(ParallelMode.SEQUENCE) + else: + seq_size = 1 + + tensor_shape = (gpc.config.SEQ_LENGTH // seq_size, + gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, gpc.config.HIDDEN_SIZE) + return tensor_shape + else: + return None + + +def pack_return_tensors(return_tensors): + output, label = tuple(zip(*return_tensors)) + if isinstance(output[0], torch.Tensor): + output = torch.cat(output, dim=0) + elif isinstance(output[0], (list, tuple)): + output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output)) + else: + raise TypeError(f'Output of model must be tensor or list/tuple of tensors') + if isinstance(label[0], torch.Tensor): + label = torch.cat(label, dim=0) + else: + merged_label = {k: [] for k in label[0].keys()} + for d in label: + for k, v in d.items(): + merged_label[k].append(v) + label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()} + return output, label + + +class PipelineSchedule(BaseSchedule): + """A helper schedule class for pipeline parallelism running environment. + It uses non-interleaved 1F1B strategy. Other properties are similar as + :class:`NonPipelineSchedule`. + + Args: + num_microbatches (int): The number of microbatches. + data_process_func (Callable, optional): + The preprocessing function which receives a batch of data, and it will be executed in `load_batch`. + tensor_shape (torch.Size, optional): Specified shape in pipeline communication. + scatter_gather_tensors (bool, optional): + If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. + + Example: + + # this shows an example of customized data_process_func + def data_process_func(stage_output, dataloader_output): + output1, output2 = stage_output + item1, item2, item3 = dataloader_output + + # assume item2 is not needed + data = (output1, output2, item1) + label = item3 + return data, label + + """ + + def __init__(self, + num_microbatches, + data_process_func: Callable = None, + tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, + scatter_gather_tensors: bool = False): + + # we need to make sure that the signature of the data_process_func is valid + if data_process_func: + sig = inspect.signature(data_process_func) + assert len(sig.parameters) == 2, \ + 'The data_process_func only takes in two parameters for NonPipelineSchedule, ' \ + 'which is the tensors passed by the previous pipeline stage and the dataloader output from this stage, ' \ + 'i.e. data_process_func(stage_output, dataloader_output).' + + super().__init__(data_process_func=data_process_func) + + assert num_microbatches > 0, f'expected num_microbatches to be larger then 1, but got {num_microbatches}' + + self.num_microbatches = num_microbatches + self.dtype = torch.float + assert not isinstance(tensor_shape, + int), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]." + if tensor_shape is None: + self.tensor_shape = tensor_shape + elif isinstance(tensor_shape, torch.Size): + self.tensor_shape = tensor_shape + else: + self.tensor_shape = torch.Size(tensor_shape) + self.scatter_gather_tensors = False + if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1: + self.scatter_gather_tensors = scatter_gather_tensors + self._logger = get_dist_logger() + + # cache for the batch data + self.batch_data = None + + def load_batch(self, data_iter): + # Pipeline schedule just puts data in memory + batch_data = super().load_batch(data_iter, to_gpu=False) + self.microbatch_offset = 0 + assert self.batch_size % self.num_microbatches == 0, \ + "Batch size should divided by the number of microbatches" + self.microbatch_size = self.batch_size // self.num_microbatches + self.batch_data = batch_data + + def _get_data_slice(self, data, offset): + if isinstance(data, torch.Tensor): + return data[offset:offset + self.microbatch_size] + elif isinstance(data, (list, tuple)): + data_dict = {} + for element in data: + if isinstance(element, dict): + data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()}) + elif data_dict: + data_dict['label'] = element[offset:offset + self.microbatch_size] + if data_dict: + return data_dict + return [val[offset:offset + self.microbatch_size] for val in data] + elif isinstance(data, dict): + return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()} + else: + raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") + + def load_micro_batch(self): + mciro_batch_data = self._get_data_slice(self.batch_data, self.microbatch_offset) + self.microbatch_offset += self.microbatch_size + return self._move_to_device(mciro_batch_data) + + def pre_processing(self, engine): + from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 + # TODO: remove this after testing new zero with pipeline parallelism + model = engine.model + if isinstance(model, NaiveAMPModel): + self.dtype = torch.half + model = model.model + if isinstance(model, ShardedModelV2): + self.dtype = torch.half + model = model.module + # sig = inspect.signature(model.forward) + # for p in sig.parameters.values(): + # assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported' + + @staticmethod + def _call_engine(model, data): + if data is not None: + if isinstance(data, torch.Tensor): + return model(data) + elif isinstance(data, (list, tuple)): + return model(*data) + elif isinstance(data, dict): + stage_output = None + if 'stage_output' in data: + stage_output = data.pop('stage_output') + if stage_output is None: + return model(**data) + elif isinstance(stage_output, torch.Tensor): + return model(stage_output, **data) + elif isinstance(stage_output, (tuple, list)): + return model(*stage_output, **data) + else: + raise TypeError( + f"Expected stage_output to be of type torch.Tensor, list, or tuple, but got {type(stage_output)}" + ) + else: + raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") + + def _get_actual_forward_func(self, module): + if isinstance(module, NaiveAMPModel): + sig = inspect.signature(module.model.forward) + elif hasattr(module, 'colo_attr'): + sig = inspect.signature(module.module.forward) + else: + sig = inspect.signature(module.forward) + return sig + + def _get_data_label_for_current_step(self, stage_output, micro_batch_data, criterion, model): + if self.data_process_func: + # use customized function to get data and label + data, label = self.data_process_func(stage_output, micro_batch_data) + else: + if isinstance(micro_batch_data, (tuple, list)): + if gpc.is_first_rank(ParallelMode.PIPELINE): + # for the first stage, we use the data from the + # dataloader output by default + data, label = micro_batch_data + else: + # for non-first stage, we use the output passed + # by the previous as the model input + data = stage_output + _, label = micro_batch_data + elif isinstance(micro_batch_data, dict): + data = {} + data['stage_output'] = stage_output + if 'label' in micro_batch_data: + label = micro_batch_data.pop('label') + else: + label = None + load_data = micro_batch_data + data.update(load_data) + return data, label + + def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None): + """Forward step for passed-in model. If it is the first stage, the input tensor + is obtained from data_iterator, otherwise the passed-in input_obj is used. + Returns output tensor. This is a helper function and can be ignored by users. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage. + return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return. + return_output_label (bool, optional): Whether returns output labels. + accum_loss (optional): Where accumulated loss stores. + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. + """ + micro_batch_data = self.load_micro_batch() + + data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, engine.model) + + output_obj = self._call_engine(engine.model, data) + + if gpc.is_last_rank(ParallelMode.PIPELINE): + if return_output_label: + return_tensors.append((output_obj, label)) + if accum_loss is not None: + loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches + accum_loss.add_(loss_reduced.detach()) + return loss_reduced + else: + # forward only, it's useless since backward is not needed + return output_obj + else: + if isinstance(output_obj, torch.Tensor): + self._logger.debug( + f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}' + ) + return output_obj + + def _backward_step(self, engine, input_obj, output_obj, output_obj_grad): + """Backward step through the passed-in output tensor. If it is the last stage, the + output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor. + Returns the gradients with respect to the input tensor (None if first stage). + This is a helper function and can be ignored by users. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage. + output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage. + output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: gradient of input tensor. + """ + + # Retain the grad on the input_obj. + if input_obj is not None: + if isinstance(input_obj, torch.Tensor): + input_obj.retain_grad() + else: + for in_tensor in input_obj: + if in_tensor is not None: + in_tensor.retain_grad() + # Backward pass. + if output_obj_grad is None: + engine.backward(output_obj) + else: + engine.backward_by_grad(output_obj, output_obj_grad) + + # Collect the grad of the input_obj. + input_obj_grad = None + if input_obj is not None: + if isinstance(input_obj, torch.Tensor): + input_obj_grad = input_obj.grad + else: + input_obj_grad = [] + for in_tensor in input_obj: + input_obj_grad.append(in_tensor.grad) + + return input_obj_grad + + def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): + """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. + Returns a tuple with losses if the last stage, an empty tuple otherwise. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). + forward_only (bool, optional): + Whether run forward step only. Default is false. If true, no backward will be run. + return_loss (bool, optional): Whether returns the loss value. Default is true. + return_output_label (bool, optional): If False, the output and label won't be returned. + + Returns: + Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. + """ + + assert forward_only or return_loss, \ + 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' + self.load_batch(data_iter) + num_warmup_microbatches = \ + (gpc.get_world_size(ParallelMode.PIPELINE) + - gpc.get_local_rank(ParallelMode.PIPELINE) - 1) + num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches) + num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches + + # Input, output tensors only need to be saved when doing backward passes + input_objs = None + output_objs = None + if not forward_only: + input_objs = [] + output_objs = [] + return_tensors = [] + if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None + # Used for tensor meta information communication + ft_shapes = self.tensor_shape + bt_shapes = None + fs_checker = self.tensor_shape is None + + # Run warmup forward passes. + for i in range(num_warmup_microbatches): + if not gpc.is_first_rank(ParallelMode.PIPELINE): + ft_shapes = comm.recv_obj_meta(ft_shapes) + input_obj = comm.recv_forward(ft_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + output_obj = self._forward_step(engine, + input_obj, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss) + if not gpc.is_last_rank(ParallelMode.PIPELINE): + if isinstance(output_obj, torch.Tensor): + bt_shapes = output_obj.shape + else: + bt_shapes = [] + for out_tensor in output_obj: + bt_shapes.append(out_tensor.shape) + fs_checker = comm.send_obj_meta(output_obj, fs_checker) + comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors) + + if not forward_only: + input_objs.append(input_obj) + output_objs.append(output_obj) + + # Before running 1F1B, need to receive first forward tensor. + # If all microbatches are run in warmup / cooldown phase, then no need to + # receive this tensor here. + if num_microbatches_remaining > 0: + if not gpc.is_first_rank(ParallelMode.PIPELINE): + ft_shapes = comm.recv_obj_meta(ft_shapes) + input_obj = comm.recv_forward(ft_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + last_iteration = (i == (num_microbatches_remaining - 1)) + + output_obj = self._forward_step(engine, + input_obj, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss) + if forward_only: + comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors) + + if not last_iteration: + input_obj = comm.recv_forward(ft_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + + else: + output_obj_grad = comm.send_forward_recv_backward(output_obj, + bt_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + + # Add input_obj and output_obj to end of list. + input_objs.append(input_obj) + output_objs.append(output_obj) + + # Pop output_obj and output_obj from the start of the list for + # the backward pass. + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) + + if last_iteration: + input_obj = None + comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) + else: + input_obj = comm.send_backward_recv_forward(input_obj_grad, + ft_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + + # Run cooldown backward passes. + if not forward_only: + for i in range(num_warmup_microbatches): + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + + output_obj_grad = comm.recv_backward(bt_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) + + comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) + + if len(return_tensors) > 0: + output, label = pack_return_tensors(return_tensors) + return output, label, accum_loss + else: + return None, None, accum_loss + + +class InterleavedPipelineSchedule(PipelineSchedule): + + def __init__(self, + num_microbatches: int, + num_model_chunks: int, + data_process_func: Callable = None, + tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, + scatter_gather_tensors: bool = False): + """A helper schedule class for pipeline parallelism running environment. + It uses interleaved 1F1B strategy. Other properties are similar as + :class:`NonPipelineSchedule`. + + Args: + num_microbatches (int): The number of microbatches. + num_model_chunks (int): The number of model chunks. + data_process_func (Callable, optional): + The preprocessing function which receives a batch of data, and it will be executed in `load_batch`. + tensor_shape (torch.Size, optional): Specified shape in pipeline communication. + scatter_gather_tensors (bool, optional): + If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. + """ + assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \ + 'num_microbatches must be an integer multiple of pipeline parallel world size' + assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \ + f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}' + super().__init__(num_microbatches, + data_process_func=data_process_func, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather_tensors) + gpc.set_virtual_pipeline_parallel_size(num_model_chunks) + gpc.set_virtual_pipeline_parallel_rank(0) + self.num_model_chunks = num_model_chunks + + def pre_processing(self, engine): + from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 + if isinstance(engine.model, ShardedModelV2): + self.dtype = torch.half + elif isinstance(engine.model[0], NaiveAMPModel): + self.dtype = torch.half + for model in engine.model: + if isinstance(model, NaiveAMPModel): + model = model.model + sig = inspect.signature(model.forward) + for p in sig.parameters.values(): + assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported' + + def load_batch(self, data_iter): + super().load_batch(data_iter) + # overwrite microbatch_offset, since model chunks load the same microbatch, and should tract the offset + self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] + + def load_micro_batch(self, model_chunk_id): + data = self._get_data_slice(self.batch_data, self.microbatch_offset[model_chunk_id]) + self.microbatch_offset[model_chunk_id] += self.microbatch_size + return self._move_to_device(data) + + def _forward_step(self, + engine, + model_chunk_id, + input_obj, + return_tensors, + return_output_label=True, + accum_loss=None): + """Forward step for passed-in model. If it is the first stage, the input tensor + is obtained from data_iterator, otherwise the passed-in input_obj is used. + Returns output tensor. This is a helper function and can be ignored by users. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + model_chunk_id (int): The id of model chunks. + input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage. + return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return. + return_output_label (bool, optional): Whether returns output labels. + accum_loss (optional): Where accumulated loss stores. + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. + """ + micro_batch_data = self.load_micro_batch(model_chunk_id) + data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, + engine.model[model_chunk_id]) + + output_obj = self._call_engine(engine.model[model_chunk_id], data) + + if gpc.is_pipeline_last_stage(): + if return_output_label: + return_tensors.append((output_obj, label)) + if accum_loss is not None: + loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches + accum_loss.add_(loss_reduced.detach()) + return loss_reduced + else: + # forward only, it's useless since backward is not needed + return output_obj + else: + if isinstance(output_obj, torch.Tensor): + self._logger.debug( + f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}' + ) + return output_obj + + def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): + """Run interleaved 1F1B schedule (model split into model chunks), with + communication between pipeline stages as needed. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). + forward_only (bool, optional): + Whether run forward step only. Default is false. If true, no backward will be run. + return_loss (bool, optional): Whether returns the loss value. Default is true. + return_output_label (bool, optional): If False, the output and label won't be returned. + + Returns: + Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. + The loss would be returned only in the last stage. + """ + assert forward_only or return_loss, \ + 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' + self.load_batch(data_iter) + model = engine.model + input_objs = [[] for _ in range(len(model))] + output_objs = [[] for _ in range(len(model))] + return_tensors = [] + if not forward_only: + output_obj_grads = [[] for _ in range(len(model))] + if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None + + # Used for obj meta information communication + input_obj_shapes = [self.tensor_shape for _ in range(len(model))] + output_obj_shapes = [None for _ in range(len(model))] + send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))] + + pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + # Compute number of warmup and remaining microbatches. + num_model_chunks = len(model) + num_microbatches = self.num_microbatches * num_model_chunks + all_warmup_microbatches = False + if forward_only: + num_warmup_microbatches = num_microbatches + else: + # Run all forward passes and then all backward passes if number of + # microbatches is just the number of pipeline stages. + # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on + # all workers, followed by more microbatches after depending on + # stage ID (more forward passes for earlier stages, later stages can + # immediately start with 1F1B). + if self.num_microbatches == pipeline_parallel_size: + num_warmup_microbatches = num_microbatches + all_warmup_microbatches = True + else: + num_warmup_microbatches = \ + (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 + num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) + num_microbatches_remaining = \ + num_microbatches - num_warmup_microbatches + + def get_model_chunk_id(microbatch_id, forward): + """Helper method to get the model chunk ID given the iteration number.""" + microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks) + model_chunk_id = microbatch_id_in_group // pipeline_parallel_size + if not forward: + model_chunk_id = (num_model_chunks - model_chunk_id - 1) + return model_chunk_id + + def _forward_step_helper(microbatch_id): + """Helper method to run forward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + forward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) + gpc.set_virtual_pipeline_parallel_rank(model_chunk_id) + + # forward step + if gpc.is_pipeline_first_stage(): + if len(input_objs[model_chunk_id]) == \ + len(output_objs[model_chunk_id]): + input_objs[model_chunk_id].append(None) + input_obj = input_objs[model_chunk_id][-1] + output_obj = self._forward_step(engine, + model_chunk_id, + input_obj, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss) + output_objs[model_chunk_id].append(output_obj) + + # if forward-only, no need to save tensors for a backward pass + if forward_only: + input_objs[model_chunk_id].pop() + output_objs[model_chunk_id].pop() + + return output_obj + + def _backward_step_helper(microbatch_id): + """Helper method to run backward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + backward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) + gpc.set_virtual_pipeline_parallel_rank(model_chunk_id) + + if gpc.is_pipeline_last_stage(): + if len(output_obj_grads[model_chunk_id]) == 0: + output_obj_grads[model_chunk_id].append(None) + input_obj = input_objs[model_chunk_id].pop(0) + output_obj = output_objs[model_chunk_id].pop(0) + output_obj_grad = output_obj_grads[model_chunk_id].pop(0) + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) + + return input_obj_grad + + # Run warmup forward passes. + gpc.set_virtual_pipeline_parallel_rank(0) + if not gpc.is_pipeline_first_stage(): + input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0]) + input_objs[0].append( + comm.recv_forward(input_obj_shapes[0], dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors)) + + for k in range(num_warmup_microbatches): + model_chunk_id = get_model_chunk_id(k, forward=True) + output_obj = _forward_step_helper(k) + if not gpc.is_pipeline_last_stage(): + if isinstance(output_obj, torch.Tensor): + output_obj_shapes[model_chunk_id] = output_obj.shape + else: + output_obj_shapes[model_chunk_id] = [] + for out_tensor in output_obj: + output_obj_shapes[model_chunk_id].append(out_tensor.shape) + send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(output_obj, + send_tensor_shape_flags[model_chunk_id]) + # Determine if tensor should be received from previous stage. + next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) + recv_prev = True + if gpc.is_pipeline_first_stage(ignore_virtual=True): + if next_forward_model_chunk_id == 0: + recv_prev = False + if k == (num_microbatches - 1): + recv_prev = False + + # Don't send tensor downstream if on last stage. + if gpc.is_pipeline_last_stage(): + output_obj = None + + with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id): + if not gpc.is_pipeline_first_stage(): + input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta( + input_obj_shapes[next_forward_model_chunk_id]) + # Send and receive tensors as appropriate (send tensors computed + # in this iteration; receive tensors for next iteration). + input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None + if k == (num_warmup_microbatches - 1) and not forward_only and \ + not all_warmup_microbatches: + input_obj_grad = None + recv_next = True + if gpc.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None + input_obj, output_obj_grad = \ + comm.send_forward_backward_recv_forward_backward( + output_obj, input_obj_grad, + input_shape, + output_shape, + recv_prev=recv_prev, recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + output_obj_grads[num_model_chunks - 1].append(output_obj_grad) + else: + input_obj = \ + comm.send_forward_recv_forward( + output_obj, + input_shape, + recv_prev=recv_prev, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + input_objs[next_forward_model_chunk_id].append(input_obj) + + # Run 1F1B in steady state. + for k in range(num_microbatches_remaining): + # Forward pass. + forward_k = k + num_warmup_microbatches + output_obj = _forward_step_helper(forward_k) + + # Backward pass. + backward_k = k + input_obj_grad = _backward_step_helper(backward_k) + + # Send output_obj and input_obj_grad, receive input_obj + # and output_obj_grad. + + # Determine if current stage has anything to send in either direction, + # otherwise set obj to None. + forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id) + if gpc.is_pipeline_last_stage(): + output_obj = None + + backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) + gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id) + if gpc.is_pipeline_first_stage(): + input_obj_grad = None + + # Determine if peers are sending, and where in data structure to put + # received tensors. + recv_prev = True + if gpc.is_pipeline_first_stage(ignore_virtual=True): + # First stage is ahead of last stage by (pipeline_parallel_size - 1). + next_forward_model_chunk_id = get_model_chunk_id(forward_k - (pipeline_parallel_size - 1), forward=True) + if next_forward_model_chunk_id == (num_model_chunks - 1): + recv_prev = False + next_forward_model_chunk_id += 1 + else: + next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) + + recv_next = True + if gpc.is_pipeline_last_stage(ignore_virtual=True): + # Last stage is ahead of first stage by (pipeline_parallel_size - 1). + next_backward_model_chunk_id = get_model_chunk_id(backward_k - (pipeline_parallel_size - 1), + forward=False) + if next_backward_model_chunk_id == 0: + recv_next = False + next_backward_model_chunk_id -= 1 + else: + next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) + + # If last iteration, don't receive; we already received one extra + # before the start of the for loop. + if k == (num_microbatches_remaining - 1): + recv_prev = False + + input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None + output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None + # Communicate objs. + input_obj, output_obj_grad = \ + comm.send_forward_backward_recv_forward_backward( + output_obj, input_obj_grad, + input_shape, + output_shape, + recv_prev=recv_prev, recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + + # Put input_obj and output_obj_grad in data structures in the + # right location. + if recv_prev: + input_objs[next_forward_model_chunk_id].append(input_obj) + if recv_next: + output_obj_grads[next_backward_model_chunk_id].append(output_obj_grad) + + # Run cooldown backward passes (flush out pipeline). + if not forward_only: + if all_warmup_microbatches: + output_obj_grads[num_model_chunks - 1].append( + comm.recv_backward(output_obj_shapes[num_model_chunks - 1], + scatter_gather_tensors=self.scatter_gather_tensors)) + for k in range(num_microbatches_remaining, num_microbatches): + input_obj_grad = _backward_step_helper(k) + next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) + recv_next = True + if gpc.is_pipeline_last_stage(ignore_virtual=True): + if next_backward_model_chunk_id == (num_model_chunks - 1): + recv_next = False + if k == (num_microbatches - 1): + recv_next = False + output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None + output_obj_grads[next_backward_model_chunk_id].append( + comm.send_backward_recv_backward(input_obj_grad, + output_shape, + recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors)) + + if len(return_tensors) > 0: + output, label = pack_return_tensors(return_tensors) + return output, label, accum_loss + else: + return None, None, accum_loss diff --git a/colossalai/engine/schedule/_pipeline_schedule_v2.py b/colossalai/engine/schedule/_pipeline_schedule_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..50a87aafad026f7439d6082303999791c1a11796 --- /dev/null +++ b/colossalai/engine/schedule/_pipeline_schedule_v2.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Tuple, Iterable + +from colossalai import engine +import colossalai.communication.p2p_v2 as comm +import torch.cuda +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils.cuda import get_current_device + +from ._pipeline_schedule import PipelineSchedule + + +def pack_return_tensors(return_tensors): + output, label = tuple(zip(*return_tensors)) + if isinstance(output[0], torch.Tensor): + output = torch.cat(output, dim=0) + elif isinstance(output[0], (list, tuple)): + output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output)) + else: + raise TypeError(f'Output of model must be tensor or list/tuple of tensors') + if isinstance(label[0], torch.Tensor): + label = torch.cat(label, dim=0) + else: + merged_label = {k: [] for k in label[0].keys()} + for d in label: + for k, v in d.items(): + merged_label[k].append(v) + label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()} + return output, label + + +class PipelineScheduleV2(PipelineSchedule): + """Derived class of PipelineSchedule, the only difference is that + forward_backward_step is reconstructed with p2p_v2 + + Args: + num_microbatches (int): The number of microbatches. + data_process_func (Callable, optional): + The preprocessing function which receives a batch of data, and it will be executed in `load_batch`. + tensor_shape (torch.Size, optional): Specified shape in pipeline communication. + scatter_gather_tensors (bool, optional): + If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. + + Example: + + # this shows an example of customized data_process_func + def data_process_func(stage_output, dataloader_output): + output1, output2 = stage_output + item1, item2, item3 = dataloader_output + + # assume item2 is not needed + data = (output1, output2, item1) + label = item3 + return data, label + + """ + + def forward_backward_step(self, + engine: engine.Engine, + data_iter: Iterable, + forward_only=False, + return_loss=True, + return_output_label=True) -> Tuple[torch.Tensor]: + """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. + Returns a tuple with losses if the last stage, an empty tuple otherwise. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). + forward_only (bool, optional): + Whether run forward step only. Default is false. If true, no backward will be run. + return_loss (bool, optional): Whether returns the loss value. Default is true. + return_output_label (bool, optional): If False, the output and label won't be returned. + + Returns: + Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. + """ + + assert forward_only or return_loss, \ + 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' + self.load_batch(data_iter) + + # num_warmup_microbatches is the step when not all the processers are working + num_warmup_microbatches = \ + (gpc.get_world_size(ParallelMode.PIPELINE) + - gpc.get_local_rank(ParallelMode.PIPELINE) - 1) + num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches) + num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches + + # Input, output tensors only need to be saved when doing backward passes + input_objs = None + output_objs = None + # local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + if not forward_only: + input_objs = [] + output_objs = [] + return_tensors = [] + if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None + + # Run warmup forward passes. + for i in range(num_warmup_microbatches): + input_obj = comm.recv_forward() + + output_obj = self._forward_step(engine, + input_obj, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss) + + comm.send_forward(output_obj) + + if not forward_only: + input_objs.append(input_obj) + output_objs.append(output_obj) + + # Before running 1F1B, need to receive first forward tensor. + # If all microbatches are run in warmup / cooldown phase, then no need to + # receive this tensor here. + if num_microbatches_remaining > 0: + input_obj = comm.recv_forward() + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + last_iteration = (i == (num_microbatches_remaining - 1)) + + output_obj = self._forward_step(engine, + input_obj, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss) + if forward_only: + comm.send_forward(output_obj) + + if not last_iteration: + input_obj = comm.recv_forward() + + else: + # TODO adjust here + comm.send_forward(output_obj) + output_obj_grad = comm.recv_backward() + + # Add input_obj and output_obj to end of list. + input_objs.append(input_obj) + output_objs.append(output_obj) + + # Pop output_obj and output_obj from the start of the list for + # the backward pass. + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) + + if last_iteration: + input_obj = None + comm.send_backward(input_obj_grad) + else: + input_obj = comm.recv_forward() + comm.send_backward(input_obj_grad) + + # Run cooldown backward passes. + if not forward_only: + for i in range(num_warmup_microbatches): + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + + output_obj_grad = comm.recv_backward() + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) + comm.send_backward(input_obj_grad) + + if len(return_tensors) > 0: + output, label = pack_return_tensors(return_tensors) + return output, label, accum_loss + else: + return None, None, accum_loss diff --git a/colossalai/fx/__init__.py b/colossalai/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d39fa579943d7a4f3cccd7ea7f6e2a20613373c8 --- /dev/null +++ b/colossalai/fx/__init__.py @@ -0,0 +1,4 @@ +from ._compatibility import compatibility, is_compatible_with_meta +from .graph_module import ColoGraphModule +from .passes import MetaInfoProp, metainfo_trace +from .tracer import ColoTracer, meta_trace, symbolic_trace diff --git a/colossalai/fx/_compatibility.py b/colossalai/fx/_compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..1264032703019de5b1920a83e4da5d3e395537a1 --- /dev/null +++ b/colossalai/fx/_compatibility.py @@ -0,0 +1,46 @@ +from typing import Callable + +import torch + +try: + from . import _meta_registrations + META_COMPATIBILITY = True +except: + META_COMPATIBILITY = False + + +def compatibility(is_backward_compatible: bool = False) -> Callable: + """A decorator to make a function compatible with different versions of PyTorch. + + Args: + is_backward_compatible (bool, optional): Whether the function is backward compatible. Defaults to False. + + Returns: + Callable: The decorated function + """ + + def decorator(func): + if META_COMPATIBILITY: + return func + else: + if is_backward_compatible: + return func + else: + + def wrapper(*args, **kwargs): + raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}') + + return wrapper + + return decorator + + +def is_compatible_with_meta() -> bool: + """Check the meta compatibility. Normally it should be called before importing some of the `colossalai.fx` + modules. If the meta compatibility is not satisfied, the `colossalai.fx` modules will be replaced by its + experimental counterparts. + + Returns: + bool: The meta compatibility + """ + return META_COMPATIBILITY diff --git a/colossalai/fx/_meta_registrations.py b/colossalai/fx/_meta_registrations.py new file mode 100644 index 0000000000000000000000000000000000000000..d614219dbef07e90e8a11a5c273019cb2928cd40 --- /dev/null +++ b/colossalai/fx/_meta_registrations.py @@ -0,0 +1,472 @@ +# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py +# should be activated for PyTorch version 1.12.0 and below +# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml +# for more meta_registrations + +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch.utils._pytree import tree_map + +aten = torch.ops.aten + +meta_lib = torch.library.Library("aten", "IMPL", "Meta") + +meta_table = {} + + +def register_meta(op, register_dispatcher=True): + + def wrapper(f): + + def add_func(op): + meta_table[op] = f + if register_dispatcher: + name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__) + try: + meta_lib.impl(name, f) + except: + pass + + tree_map(add_func, op) + return f + + return wrapper + + +# ============================== Convolutions ====================================== +# https://github.com/pytorch/pytorch/pull/79834 +@register_meta(aten.convolution.default) +def meta_conv( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + is_transposed: bool, + output_padding: List[int], + groups: int, +): + + def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: + """ + Formula to apply to calculate the length of some dimension of the output + See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + Args: + ln: length of the dimension + p: padding in that dim + d: dilation in that dim + k: kernel size in that dim + s: stride in that dim + Returns: + The output length + """ + return (ln + 2 * p - d * (k - 1) - 1) // s + 1 + + def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int: + """ + Formula to apply to calculate the length of some dimension of the output + if transposed convolution is used. + See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html + Args: + ln: length of the dimension + p: padding in that dim + d: dilation in that dim + k: kernel size in that dim + s: stride in that dim + op: output padding in that dim + Returns: + The output length + """ + return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1 + + def calc_conv_nd_return_shape( + dims: torch.Size, + kernel_size: torch.Size, + stride: Union[List[int], int], + padding: Union[List[int], int], + dilation: Union[List[int], int], + output_padding: Optional[Union[List[int], int]] = None, + ): + ret_shape = [] + if isinstance(stride, int): + stride = [stride] * len(dims) + elif len(stride) == 1: + stride = [stride[0]] * len(dims) + + if isinstance(padding, int): + padding = [padding] * len(dims) + elif len(padding) == 1: + padding = [padding[0]] * len(dims) + + if isinstance(dilation, int): + dilation = [dilation] * len(dims) + elif len(dilation) == 1: + dilation = [dilation[0]] * len(dims) + + output_padding_list: Optional[List[int]] = None + if output_padding: + if isinstance(output_padding, int): + output_padding_list = [output_padding] * len(dims) + elif len(output_padding) == 1: + output_padding_list = [output_padding[0]] * len(dims) + else: + output_padding_list = output_padding + + for i in range(len(dims)): + # If output_padding is present, we are dealing with a transposed convolution + if output_padding_list: + ret_shape.append( + _formula_transposed( + dims[i], + padding[i], + dilation[i], + kernel_size[i], + stride[i], + output_padding_list[i], + )) + else: + ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])) + return ret_shape + + def pick_memory_format(): + if input_tensor.is_contiguous(memory_format=torch.channels_last): + return torch.channels_last + elif input_tensor.is_contiguous(memory_format=torch.contiguous_format): + return torch.contiguous_format + elif input_tensor.is_contiguous(memory_format=torch.preserve_format): + return torch.preserve_format + + kernel_size = weight.shape[2:] + dims = input_tensor.shape[2:] + if is_transposed: + out_channels = groups * weight.shape[1] + + shape_out = calc_conv_nd_return_shape( + dims, + kernel_size, + stride, + padding, + dilation, + output_padding, + ) + + else: + out_channels = weight.shape[0] + if weight.shape[1] != input_tensor.shape[1] / groups: + raise RuntimeError("Invalid channel dimensions") + shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation) + out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) + mem_fmt = pick_memory_format() + out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] + return out + + +@register_meta(aten._convolution.default) +def meta_conv_1( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + is_transposed: bool, + output_padding: List[int], + groups: int, + *extra_args +): + out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups) + return out + + +@register_meta(aten.convolution_backward.default) +def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, + padding, dilation, transposed, output_padding, groups, output_mask): + return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta') + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp +@register_meta(aten._adaptive_avg_pool2d_backward.default) +def meta_adaptive_avg_pool2d_backward( + grad_output: torch.Tensor, + input: torch.Tensor, +): + grad_input = torch.empty_like(input) + return grad_input + + +# ================================ RNN ============================================= +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp +@register_meta(aten._cudnn_rnn.default) +def meta_cuda_rnn( + input: torch.Tensor, + weight: torch.Tensor, + weight_stride0: int, + weight_buf: torch.Tensor, + hx: torch.Tensor, + cx: Optional[torch.Tensor] = None, + *args, + **kwargs, +): + if cx is not None: + return torch.empty_like(input), torch.empty_like(hx), torch.empty_like(cx) + else: + return torch.empty_like(input), torch.empty_like(hx), torch.empty((), device='meta') + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp +@register_meta(aten._cudnn_rnn_backward.default) +def meta_cudnn_rnn_backward(input: torch.Tensor, + weight: torch.Tensor, + weight_stride0: int, + hx: torch.Tensor, + cx: Optional[torch.Tensor] = None, + *args, + **kwargs): + print(input, weight, hx, cx) + grad_input = torch.empty_like(input) + grad_weight = torch.empty_like(weight) + grad_hx = torch.empty_like(hx) + grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta') + return grad_input, grad_weight, grad_hx, grad_cx + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp +# ============================== Activations ======================================= +@register_meta(aten.relu.default) +def meta_relu(input: torch.Tensor): + return torch.empty_like(input) + + +@register_meta(aten.prelu.default) +def meta_prelu(input: torch.Tensor, weight: torch.Tensor): + return torch.empty_like(input) + + +@register_meta(aten.hardswish.default) +def meta_hardswish(input: torch.Tensor): + return torch.empty_like(input) + + +@register_meta(aten.hardtanh.default) +def meta_hardtanh(input: torch.Tensor, min, max): + return torch.empty_like(input) + + +@register_meta(aten.hardswish_backward.default) +def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor): + grad_in = torch.empty_like(input) + return grad_in + + +@register_meta(aten.hardtanh_backward.default) +def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val: int, max_val: int): + grad_in = torch.empty_like(input) + return grad_in + + +# ============================== Normalization ===================================== +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp +@register_meta(aten.native_batch_norm.default) +def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): + n_input = input.size(1) + + output = torch.empty_like(input) + running_mean = torch.empty((n_input), device='meta') + running_var = torch.empty((n_input), device='meta') + return output, running_mean, running_var + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp +@register_meta(aten.native_batch_norm_backward.default) +def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean, + save_invstd, train, eps, output_mask): + dX = torch.empty_like(input) + dgamma = torch.empty_like(weight) + dbeta = torch.empty_like(weight) + return dX, dgamma, dbeta + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp +@register_meta(aten.cudnn_batch_norm.default) +def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): + n_input = input.size(1) + + output = torch.empty_like(input) + running_mean = torch.empty((n_input), device='meta') + running_var = torch.empty((n_input), device='meta') + reserve = torch.empty((0), dtype=torch.uint8, device='meta') + return output, running_mean, running_var, reserve + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp +# NB: CuDNN only implements the backward algorithm for batchnorm +# in training mode (evaluation mode batchnorm has a different algorithm), +# which is why this doesn't accept a 'training' parameter. +@register_meta(aten.cudnn_batch_norm_backward.default) +def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, + save_mean, save_invstd, eps, reserve): + dX = torch.empty_like(input) + dgamma = torch.empty_like(weight) + dbeta = torch.empty_like(weight) + return dX, dgamma, dbeta + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp +@register_meta(aten.native_layer_norm.default) +def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): + bs = input.size(0) + n_input = input.size(1) + + output = torch.empty_like(input) + running_mean = torch.empty((bs, n_input, 1), device='meta') + running_var = torch.empty((bs, n_input, 1), device='meta') + return output, running_mean, running_var + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp +@register_meta(aten.native_layer_norm_backward.default) +def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, + grad_input_mask): + dX = torch.empty_like(input) + dgamma = torch.empty_like(weight) + dbeta = torch.empty_like(bias) + return dX, dgamma, dbeta + + +# ================================== Misc ========================================== +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml +@register_meta(aten.roll.default) +def meta_roll(input: torch.Tensor, shifts, dims): + return input + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp +@register_meta(aten._local_scalar_dense.default) +def meta_local_scalar_dense(self: torch.Tensor): + return 0 + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp +@register_meta(aten.where.self) +def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor): + result_type = torch.result_type(self, other) + return torch.empty_like(self, dtype=result_type) + + +@register_meta(aten.index.Tensor) +def meta_index_Tensor(self, indices): + assert indices, "at least one index must be provided" + # aten::index is the internal advanced indexing implementation + # checkIndexTensorTypes and expandTensors + result: List[Optional[torch.Tensor]] = [] + for i, index in enumerate(indices): + if index is not None: + assert index.dtype in [torch.long, torch.int8, torch.bool],\ + "tensors used as indices must be long, byte or bool tensors" + if index.dtype in [torch.int8, torch.bool]: + nonzero = index.nonzero() + k = len(result) + assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}" + for j in range(index.ndim): + assert index.shape[j] == self.shape[ + k + + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" + result.append(nonzero.select(1, j)) + else: + result.append(index) + else: + result.append(index) + indices = result + assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})" + # expand_outplace + import torch._refs as refs + + indices = list(refs._maybe_broadcast(*indices)) + # add missing null tensors + while len(indices) < self.ndim: + indices.append(None) + + # hasContiguousSubspace + # true if all non-null tensors are adjacent + # See: + # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency + state = 0 + has_contiguous_subspace = False + for index in indices: + if state == 0: + if index is not None: + state = 1 + elif state == 1: + if index is None: + state = 2 + else: + if index is not None: + break + else: + has_contiguous_subspace = True + + # transposeToFront + # This is the logic that causes the newly inserted dimensions to show up + # at the beginning of the tensor, if they're not contiguous + if not has_contiguous_subspace: + dims = [] + transposed_indices = [] + for i, index in enumerate(indices): + if index is not None: + dims.append(i) + transposed_indices.append(index) + for i, index in enumerate(indices): + if index is None: + dims.append(i) + transposed_indices.append(index) + self = self.permute(dims) + indices = transposed_indices + + # AdvancedIndex::AdvancedIndex + # Now we can assume the indices have contiguous subspace + # This is simplified from AdvancedIndex which goes to more effort + # to put the input and indices in a form so that TensorIterator can + # take them. If we write a ref for this, probably that logic should + # get implemented + before_shape: List[int] = [] + after_shape: List[int] = [] + replacement_shape: List[int] = [] + for dim, index in enumerate(indices): + if index is None: + if replacement_shape: + after_shape.append(self.shape[dim]) + else: + before_shape.append(self.shape[dim]) + else: + replacement_shape = list(index.shape) + return self.new_empty(before_shape + replacement_shape + after_shape) + + +# ============================== Embedding ========================================= +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp +@register_meta(aten.embedding_dense_backward.default) +def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, + scale_grad_by_freq): + return torch.empty((num_weights, grad_output.size(-1)), + dtype=grad_output.dtype, + device=grad_output.device, + layout=grad_output.layout) + + +# ============================== Dropout =========================================== +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp +@register_meta(aten.native_dropout.default) +def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False): + # notice that mask is bool + output = torch.empty_like(input) + mask = torch.empty_like(input, dtype=torch.bool) + return output, mask + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp +@register_meta(aten.native_dropout_backward.default) +def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float): + return torch.empty_like(grad) diff --git a/colossalai/fx/codegen/__init__.py b/colossalai/fx/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..794692f511537db5a8a54588f15ecafb41242ea2 --- /dev/null +++ b/colossalai/fx/codegen/__init__.py @@ -0,0 +1 @@ +from .activation_checkpoint_codegen import * diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..492ebf918a9c66fb8fd5aa251affe61cd4bfbc5e --- /dev/null +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -0,0 +1,1058 @@ +from typing import Any, Callable, Dict, Iterable, List, Tuple + +import torch + +import colossalai + +try: + from torch.fx.graph import ( + CodeGen, + PythonCode, + _custom_builtins, + _CustomBuiltin, + _format_target, + _is_from_torch, + _Namespace, + _origin_type_map, + inplace_methods, + magic_methods, + ) + from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg + CODEGEN_AVAILABLE = True +except: + from torch.fx.graph import ( + PythonCode, + _custom_builtins, + _CustomBuiltin, + _format_args, + _format_target, + _is_from_torch, + _Namespace, + _origin_type_map, + magic_methods, + ) + from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg + CODEGEN_AVAILABLE = False + +if CODEGEN_AVAILABLE: + __all__ = ['ActivationCheckpointCodeGen'] +else: + __all__ = ['python_code_with_activation_checkpoint'] + + +def _gen_saved_tensors_hooks(): + """ + Generate saved tensors hooks + """ + + pack_hook = """def pack_hook_input(self, x): + if getattr(x, "offload", False): + return (x.device, x.cpu()) + else: + return x + +def pack_hook_no_input(self, x): + if getattr(x, "offload", True): + return (x.device, x.cpu()) + else: + return x +""" + + unpack_hook = """def unpack_hook(self, packed): + if isinstance(packed, tuple): + device, tensor = packed + return tensor.to(device) + else: + return packed +""" + + return pack_hook, unpack_hook + + +def _gen_save_tensors_hooks_context(offload_input=True) -> str: + """Generate customized saved_tensors_hooks + Args: + offload_input (bool, optional): whether we need offload input, if offload_input=False, + we will use self.pack_hook_no_input instead. Defaults to True. + Returns: + str: generated context + """ + + if offload_input: + context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n" + else: + context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n" + return context + + +def _gen_save_on_cpu_context(): + """ + Generate save on cpu context + """ + + context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n" + return context + + +def _find_input_and_output_nodes(nodes: List[Node]): + """ + Find the input and output node names which are not found in the given list of nodes. + """ + input_nodes = [] + output_nodes = [] + + # if a node has an input node which is not in the node list + # we treat that input node as the input of the checkpoint function + for node in nodes: + for input_node in node._input_nodes.keys(): + node_repr = repr(input_node) + if input_node not in nodes and node_repr not in input_nodes: + input_nodes.append(node_repr) + + # if a node has a user node which is not in the node list + # we treat that user node as the node receiving the current node output + for node in nodes: + for output_node in node.users.keys(): + node_repr = repr(node) + if output_node not in nodes and node_repr not in output_nodes: + output_nodes.append(node_repr) + + return input_nodes, output_nodes + + +def _find_ckpt_regions(nodes: List[Node]): + """ + Find the checkpoint regions given a list of consecutive nodes. The outputs will be list + of tuples, each tuple is in the form of (start_index, end_index). + """ + ckpt_nodes = [] + ckpt_regions = [] + start = -1 + end = -1 + current_region = None + + for idx, node in enumerate(nodes): + if 'activation_checkpoint' in node.meta: + act_ckpt_label = node.meta['activation_checkpoint'] + + # this activation checkpoint label is not set yet + # meaning this is the first node of the activation ckpt region + if current_region is None: + current_region = act_ckpt_label + start = idx + + # if activation checkpoint has changed + # we restart the tracking + # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2] + if act_ckpt_label != current_region: + assert start != -1 + ckpt_regions.append((start, idx - 1)) + current_region = act_ckpt_label + start = idx + end = -1 + elif current_region is not None and not 'activation_checkpoint' in node.meta: + # used to check the case below + # node ckpt states = [ckpt, ckpt, non-ckpt] + end = idx - 1 + assert start != -1 and end != -1 + ckpt_regions.append((start, end)) + start = end = -1 + current_region = None + else: + pass + return ckpt_regions + + +def _find_offload_regions(nodes: List[Node]): + """This function is to find the offload regions + In pofo algorithm, during annotation, we will annotate the offload region with the + list in the form of [idx, offload_input, offload_bar]. idx indicates the offload + region's index, offload_input is a bool type indicates whether we need to offload + the input, offload_bar is a bool type indicates whether we need to offload all the + intermediate x_bars of this region. + """ + offload_regions = [] + offload_labels = [] + start = -1 + end = -1 + current_region = None + + for idx, node in enumerate(nodes): + if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable): + act_offload_label = node.meta['activation_offload'] + + if current_region == None: + current_region = act_offload_label + start = idx + offload_labels.append(act_offload_label) + + if act_offload_label != current_region: + assert start != -1 + offload_regions.append((start, idx - 1)) + offload_labels.append(act_offload_label) + current_region = act_offload_label + start = idx + end = -1 + + else: + if current_region is not None: + end = idx - 1 + assert start != -1 and end != -1 + offload_regions.append((start, end)) + start = end = -1 + current_region = None + + else: + pass + + return offload_regions, offload_labels + + +def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str: + """ + Generate the checkpoint function definition + """ + return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):" + + +def _gen_ckpt_output(output_vars: List[str]) -> str: + """ + Generate the return statement for checkpoint region + """ + return f"return {', '.join(output_vars)}" + + +def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant=True): + """ + Generate the checkpoint function call code text + """ + outputs = ', '.join(output_vars) + inputs = ', '.join(input_vars) + return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})' + + +def _end_of_ckpt(node: Node, check_idx: int) -> bool: + """Check if the node could end the ckpt region + Args: + node (Node): torch.fx.Node + check_idx (int): the index of checkpoint level for + nested checkpoint + Returns: + bool + """ + if 'activation_checkpoint' in node.meta: + if isinstance(node.meta['activation_checkpoint'], list): + return node.meta['activation_checkpoint'][check_idx] == None + else: + return False + else: + return True + + +def _find_nested_ckpt_regions(nodes, check_idx=0): + """ + Find the nested checkpoint regions given a list of consecutive nodes. The outputs + will be list of tuples, each tuple is in the form of (start_index, end_index). + """ + ckpt_regions = [] + start = -1 + end = -1 + current_region = None + + for idx, node in enumerate(nodes): + if 'activation_checkpoint' in node.meta: + if isinstance(node.meta['activation_checkpoint'], int): + act_ckpt_label = node.meta['activation_checkpoint'] + else: + act_ckpt_label = node.meta['activation_checkpoint'][check_idx] + + # this activation checkpoint label is not set yet + # meaning this is the first node of the activation ckpt region + if current_region is None: + current_region = act_ckpt_label + start = idx + + # if activation checkpoint has changed + # we restart the tracking + # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2] + if act_ckpt_label != current_region: + assert start != -1 + ckpt_regions.append((start, idx - 1)) + current_region = act_ckpt_label + start = idx + end = -1 + elif current_region is not None and _end_of_ckpt(node, check_idx): + # used to check the case below + # node ckpt states = [ckpt, ckpt, non-ckpt] + end = idx - 1 + assert start != -1 and end != -1 + ckpt_regions.append((start, end)) + start = end = -1 + current_region = None + else: + pass + + if current_region is not None: + end = len(nodes) - 1 + ckpt_regions.append((start, end)) + return ckpt_regions + + +def emit_ckpt_func(body, + ckpt_func, + node_list: List[Node], + emit_node_func, + delete_unused_value_func, + level=0, + in_ckpt=False): + """Emit ckpt fuction in nested way + Args: + body: forward code, in recursive calls, this part will be checkpoint + functions code + ckpt_func: checkpoint functions code, in recursive calls, this part + will be a buffer + node_list (List[Node]): list of torch.fx.Node + emit_node_func: function to emit a node + delete_unused_value_func: function to delete unused value + level (int, optional): checkpoint level. Defaults to 0. + in_ckpt (bool, optional): indicates wether the func is in recursive + call. Defaults to False. + """ + inputs, outputs = _find_input_and_output_nodes(node_list) + + # if the current checkpoint function use int as label, using old generation method + if isinstance(node_list[0].meta['activation_checkpoint'], int): + label = node_list[0].meta['activation_checkpoint'] + ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) + ckpt_func.append(f'{ckpt_fn_def}\n') + for node in node_list: + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + + ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + activation_offload = node_list[0].meta.get('activation_offload', False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + usage += "\n" + body.append(usage) + + # use nested ckpt function codegen + else: + # label given by each layer, e.g. if you are currently at level [0, 1, 1] + # the label will be '0_1_1' + label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]]) + ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) + ckpt_func.append(f'{ckpt_fn_def}\n') + + # if there is more level to fetch + if level + 1 < len(node_list[0].meta['activation_checkpoint']): + ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1) + start_idx = [item[0] for item in ckpt_regions] + end_idx = [item[1] for item in ckpt_regions] + + # use ckpt_func_buffer to store nested checkpoint functions + ckpt_func_buffer = [] + node_idx = 0 + while 1: + if node_idx >= len(node_list): + break + + if node_idx in start_idx: + ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] + emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, + delete_unused_value_func, level + 1, True) + node_idx += len(ckpt_node_list) + + else: + node = node_list[node_idx] + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + node_idx += 1 + + ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + ckpt_func += ckpt_func_buffer + activation_offload = node_list[0].meta.get('activation_offload', False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' + if in_ckpt: + usage = ' ' + usage + body.append(usage) + + # last level + else: + for node in node_list: + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + + ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + activation_offload = node_list[0].meta.get('activation_offload', False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' + if in_ckpt: + usage = ' ' + usage + body.append(usage) + + +def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): + """Emit code with nested activation checkpoint + When we detect some of the node.activation_checkpoint is a List, we will use + this function to emit the activation checkpoint codes. + Args: + body: forward code + ckpt_func: checkpoint functions code + nodes: graph.nodes + emit_node_func: function to emit node + delete_unused_value_func: function to remove the unused value + """ + ckpt_regions = _find_nested_ckpt_regions(nodes, 0) + start_idx = [item[0] for item in ckpt_regions] + end_idx = [item[1] for item in ckpt_regions] + + # find the offload regions + offload_regions, offload_labels = _find_offload_regions(nodes) + offload_starts = [item[0] for item in offload_regions] + offload_ends = [item[1] for item in offload_regions] + offload_inputs = [] + offload_outputs = [] + within_offload_region = False + + node_list = list(nodes) + + # find the input and output var names for each offload region + for idx, (start, end) in enumerate(offload_regions): + offload_node_list = node_list[start:end + 1] + inputs, outputs = _find_input_and_output_nodes(offload_node_list) + offload_inputs.append(inputs) + offload_outputs.append(outputs) + + # this flag is to prevent repeated insert of save tensors + # hooks definition in ckpt_func + is_hook_inserted = False + node_idx = 0 + while 1: + # break if we finish the processing all the nodes + if node_idx >= len(node_list): + break + + # process ckpt_regions + if node_idx in start_idx: + ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] + emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func) + node_idx += len(ckpt_node_list) + + # process node in forward function + else: + node = node_list[node_idx] + + if node_idx in offload_starts: + offload_label = offload_labels[offload_starts.index(node_idx)] + _, offload_input, offload_bar = offload_label + within_offload_region = True + + # insert hook functions if needed + if not is_hook_inserted: + pack_hook, unpack_hook = _gen_saved_tensors_hooks() + ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n") + is_hook_inserted = True + + if offload_input and offload_bar: + body.append(_gen_save_on_cpu_context()) + + elif offload_input: + for par in offload_inputs[offload_label[0]]: + body.append(f"setattr({par}, 'offload', True)\n") + body.append(_gen_save_tensors_hooks_context(offload_input=True)) + + else: + for par in offload_inputs[offload_label[0]]: + body.append(f"setattr({par}, 'offload', False)\n") + body.append(_gen_save_tensors_hooks_context(offload_input=False)) + + if within_offload_region: + emit_node_func(node, body) + body[-1] = ' ' + body[-1] + delete_unused_value_func(node, body) + + else: + emit_node_func(node, body) + delete_unused_value_func(node, body) + + if node_idx in offload_ends: + within_offload_region = False + + node_idx += 1 + + +def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): + # find the activation checkpoint regions + ckpt_regions = _find_ckpt_regions(nodes) + start_idx = [item[0] for item in ckpt_regions] + end_idx = [item[1] for item in ckpt_regions] + input_vars = [] + output_vars = [] + within_ckpt_region = False + + # find the offload regions + offload_regions, offload_labels = _find_offload_regions(nodes) + offload_starts = [item[0] for item in offload_regions] + offload_ends = [item[1] for item in offload_regions] + offload_inputs = [] + offload_outputs = [] + within_offload_region = False + + node_list = list(nodes) + + # use this variable to avoid inserting hook functions + # to ckpt_func repeatedly + is_hook_inserted = False + + # find the input and output var names for each region + for idx, (start, end) in enumerate(ckpt_regions): + ckpt_node_list = node_list[start:end + 1] + inputs, outputs = _find_input_and_output_nodes(ckpt_node_list) + input_vars.append(inputs) + output_vars.append(outputs) + + # find the input and output var names for each offload region + for idx, (start, end) in enumerate(offload_regions): + offload_node_list = node_list[start:end + 1] + inputs, outputs = _find_input_and_output_nodes(offload_node_list) + offload_inputs.append(inputs) + offload_outputs.append(outputs) + + # append code text to body + for idx, node in enumerate(node_list): + # if this is the first node of the ckpt region + # append the ckpt function defition + if idx in start_idx: + label = start_idx.index(idx) + ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label]) + ckpt_func.append(f'{ckpt_fn_def}\n') + within_ckpt_region = True + + if idx in offload_starts: + offload_label = offload_labels[offload_starts.index(idx)] + _, offload_input, offload_bar = offload_label + within_offload_region = True + + # insert hook functions if needed + if not is_hook_inserted: + pack_hook, unpack_hook = _gen_saved_tensors_hooks() + ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n") + is_hook_inserted = True + + if offload_input and offload_bar: + body.append(_gen_save_on_cpu_context()) + + elif offload_input: + for par in offload_inputs[offload_label[0]]: + body.append(f"setattr({par}, 'offload', True)\n") + body.append(_gen_save_tensors_hooks_context(offload_input=True)) + + else: + for par in offload_inputs[offload_label[0]]: + body.append(f"setattr({par}, 'offload', False)\n") + body.append(_gen_save_tensors_hooks_context(offload_input=False)) + + # NOTE: emit_node does not emit a string with newline. It depends + # on delete_unused_values to append one + # NOTE: currently we separate body and ckpt_func definition + if within_ckpt_region: + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + + elif within_offload_region: + emit_node_func(node, body) + body[-1] = ' ' + body[-1] + delete_unused_value_func(node, body) + + else: + emit_node_func(node, body) + delete_unused_value_func(node, body) + + if idx in end_idx: + # if this is the last node of the ckpt region + # generate return statement + label = end_idx.index(idx) + return_statement = _gen_ckpt_output(output_vars[label]) + return_statement = f' {return_statement}\n\n' + ckpt_func.append(return_statement) + + # we need to check if the checkpoint need to offload the input + start_node_idx = start_idx[label] + if 'activation_offload' in node_list[start_node_idx].meta: + activation_offload = node_list[start_node_idx].meta['activation_offload'] + else: + activation_offload = False + + # we need to check if the checkpoint need use_reentrant=False + use_reentrant = True + non_leaf_input = 0 + for var in input_vars[label]: + input_node = next(item for item in node_list if item.name == var) + if input_node.op != "placeholder": + non_leaf_input = 1 + for user in input_node.users: + if 'activation_checkpoint' in user.meta: + if user.meta['activation_checkpoint'] == label: + if user.op == "call_module": + if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"): + use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace + + elif user.op == "call_function": + if "inplace" in user.kwargs: + use_reentrant = not user.kwargs["inplace"] + + # if all the inputs are leaf nodes, we need to set use_reentrant = False + if not non_leaf_input: + use_reentrant = False + + # generate checkpoint function call in a new line + usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant) + usage += '\n' + body.append(usage) + within_ckpt_region = False + + if idx in offload_ends: + within_offload_region = False + + +if CODEGEN_AVAILABLE: + + class ActivationCheckpointCodeGen(CodeGen): + + def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: + free_vars: List[str] = [] + body: List[str] = [] + globals_: Dict[str, Any] = {} + wrapped_fns: Dict[str, None] = {} + + # Wrap string in list to pass by reference + maybe_return_annotation: List[str] = [''] + + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. + We call this for names that reference objects external to the + Graph, like functions or types. + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. + return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] is obj + return global_name + globals_[global_name] = obj + return global_name + + # set _custom_builtins here so that we needn't import colossalai in forward + _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) + + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) + + def type_repr(o: Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return '()' + + typename = _type_repr(o) + + if hasattr(o, '__origin__'): + # This is a generic type, e.g. typing.List[torch.Tensor] + origin_type = _origin_type_map.get(o.__origin__, o.__origin__) + origin_typename = add_global(_type_repr(origin_type), origin_type) + + if hasattr(o, '__args__'): + # Assign global names for each of the inner type variables. + args = [type_repr(arg) for arg in o.__args__] + + if len(args) == 0: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python < 3.9 + return origin_typename + + return f'{origin_typename}[{",".join(args)}]' + else: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python 3.9+ + return origin_typename + + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: + + def _get_repr(arg): + # Handle NamedTuples (if it has `_fields`) via add_global. + if isinstance(arg, tuple) and hasattr(arg, '_fields'): + qualified_name = _get_qualified_name(type(arg)) + global_name = add_global(qualified_name, type(arg)) + return f"{global_name}{repr(tuple(arg))}" + return repr(arg) + + args_s = ', '.join(_get_repr(a) for a in args) + kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) + if args_s and kwargs_s: + return f'{args_s}, {kwargs_s}' + return args_s or kwargs_s + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + # NOTE: we add a variable to distinguish body and ckpt_func + def delete_unused_values(user: Node, body): + """ + Delete values after their last use. This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == 'placeholder': + return + if user.op == 'output': + body.append('\n') + return + nodes_to_delete = user_to_last_uses.get(user, []) + if len(nodes_to_delete): + to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) + body.append(f'; {to_delete_str}\n') + else: + body.append('\n') + + # NOTE: we add a variable to distinguish body and ckpt_func + def emit_node(node: Node, body): + maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' + if node.op == 'placeholder': + assert isinstance(node.target, str) + maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' + free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') + raw_name = node.target.replace('*', '') + if raw_name != repr(node): + body.append(f'{repr(node)} = {raw_name}\n') + return + elif node.op == 'call_method': + assert isinstance(node.target, str) + body.append( + f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' + f'({_format_args(node.args[1:], node.kwargs)})') + return + elif node.op == 'call_function': + assert callable(node.target) + # pretty print operators + if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: + assert isinstance(node.args, tuple) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + return + + # pretty print inplace operators; required for jit.script to work properly + # not currently supported in normal FX graphs, but generated by torchdynamo + if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods: + body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; ' + f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}') + return + + qualified_name = _get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value + if global_name == 'getattr' and \ + isinstance(node.args, tuple) and \ + isinstance(node.args[1], str) and \ + node.args[1].isidentifier() and \ + len(node.args) == 2: + body.append( + f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') + return + body.append( + f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') + if node.meta.get('is_wrapped', False): + wrapped_fns.setdefault(global_name) + return + elif node.op == 'call_module': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + return + elif node.op == 'get_attr': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + return + elif node.op == 'output': + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + body.append(self.generate_output(node.args[0])) + return + raise NotImplementedError(f'node: {node.op} {node.target}') + + # Modified for activation checkpointing + ckpt_func = [] + + # if any node has a list of labels for activation_checkpoint, we + # will use nested type of activation checkpoint codegen + if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes): + emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) + else: + emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append('pass\n') + + if len(wrapped_fns) > 0: + wrap_name = add_global('wrap', torch.fx.wrap) + wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + else: + wrap_stmts = '' + + if self._body_transformer: + body = self._body_transformer(body) + + for name, value in self.additional_globals(): + add_global(name, value) + + # as we need colossalai.utils.checkpoint, we need to import colossalai + # in forward function + prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) + prologue = ''.join(ckpt_func) + prologue + prologue = prologue + + code = ''.join(body) + code = '\n'.join(' ' + line for line in code.split('\n')) + fn_code = f""" +{wrap_stmts} +{prologue} +{code}""" + return PythonCode(fn_code, globals_) + +else: + + def python_code_with_activation_checkpoint(self, root_module: str, namespace: _Namespace) -> PythonCode: + """ + This method is copied from the _python_code of torch.fx.graph.Graph. Modifications are made so that it can generate + code for activation checkpoint. + """ + free_vars: List[str] = [] + body: List[str] = [] + globals_: Dict[str, Any] = {} + wrapped_fns: Dict[str, None] = {} + + # Wrap string in list to pass by reference + maybe_return_annotation: List[str] = [''] + + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. + We call this for names that reference objects external to the + Graph, like functions or types. + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. + return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] is obj + return global_name + globals_[global_name] = obj + return global_name + + # set _custom_builtins here so that we needn't import colossalai in forward + _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) + + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) + + def type_repr(o: Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return '()' + + typename = _type_repr(o) + + # This is a generic type, e.g. typing.List[torch.Tensor] + if hasattr(o, '__origin__'): + origin_type = _origin_type_map.get(o.__origin__, o.__origin__) + origin_typename = add_global(_type_repr(origin_type), origin_type) + + # Assign global names for each of the inner type variables. + args = [type_repr(arg) for arg in o.__args__] + + return f'{origin_typename}[{",".join(args)}]' + + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(self.nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + # NOTE: we add a variable to distinguish body and ckpt_func + def delete_unused_values(user: Node, body): + """ + Delete values after their last use. This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == 'placeholder': + return + if user.op == 'output': + body.append('\n') + return + nodes_to_delete = user_to_last_uses.get(user, []) + if len(nodes_to_delete): + to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) + body.append(f'; {to_delete_str}\n') + else: + body.append('\n') + + # NOTE: we add a variable to distinguish body and ckpt_func + def emit_node(node: Node, body): + maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' + if node.op == 'placeholder': + assert isinstance(node.target, str) + maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' + free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') + raw_name = node.target.replace('*', '') + if raw_name != repr(node): + body.append(f'{repr(node)} = {raw_name}\n') + return + elif node.op == 'call_method': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' + f'({_format_args(node.args[1:], node.kwargs)})') + return + elif node.op == 'call_function': + assert callable(node.target) + # pretty print operators + if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: + assert isinstance(node.args, tuple) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + return + qualified_name = _get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value + if global_name == 'getattr' and \ + isinstance(node.args, tuple) and \ + isinstance(node.args[1], str) and \ + node.args[1].isidentifier() and \ + len(node.args) == 2: + body.append( + f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') + return + body.append( + f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') + if node.meta.get('is_wrapped', False): + wrapped_fns.setdefault(global_name) + return + elif node.op == 'call_module': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + return + elif node.op == 'get_attr': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + return + elif node.op == 'output': + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + if self._pytree_info is None: + body.append(f'return {repr(node.args[0])}') + else: + body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)') + return + raise NotImplementedError(f'node: {node.op} {node.target}') + + # Modified for activation checkpointing + ckpt_func = [] + + # if any node has a list of labels for activation_checkpoint, we + # will use nested type of activation checkpoint codegen + if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes): + emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) + else: + emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append('pass\n') + if self._pytree_info is not None: + orig_args = self._pytree_info.orig_args + has_orig_self = (orig_args[0] == 'self') + if has_orig_self: + free_vars.insert(0, 'self') + if len(free_vars) > 0: # pytree has placeholders in it + body.insert( + 0, + f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n") + else: + orig_args = free_vars + + if len(wrapped_fns) > 0: + wrap_name = add_global('wrap', torch.fx.wrap) + wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + else: + wrap_stmts = '' + + ckpt_func = ''.join(ckpt_func) + + # If the original function didn't have self as its first argument, we + # would have added it. + if len(orig_args) == 0 or orig_args[0] != 'self': + orig_args.insert(0, 'self') + code = ''.join(body) + code = '\n'.join(' ' + line for line in code.split('\n')) + + # as we need colossalai.utils.checkpoint, we need to import colossalai + # in forward function + fn_code = f""" +{wrap_stmts} +{ckpt_func} +def forward({', '.join(orig_args)}){maybe_return_annotation[0]}: +{code}""" + return PythonCode(fn_code, globals_) diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..fbafd326c6d4035795a8d4d372e9e3ada71dedef --- /dev/null +++ b/colossalai/fx/graph_module.py @@ -0,0 +1,164 @@ +import os +import warnings +import torch +import torch.nn as nn +from torch.nn.modules.module import _addindent +from typing import Type, Dict, List, Any, Union, Optional, Set +from pathlib import Path +try: + from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src + from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode + COLOGM = True +except: + from torch.fx.graph_module import GraphModule + from torch.fx.graph import Graph + COLOGM = False + +if COLOGM: + + class ColoGraphModule(GraphModule): + + def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'): + super().__init__(root, graph, class_name) + + def bind(self, ckpt_def, globals): + """Bind function needed for correctly execute gm forward + + We need to bind checkpoint functions and saved_tensor_hooks functions + to gm so that we could correctly execute gm forward + + Args: + ckpt_def (_type_): definition before the forward function + globals (_type_): global variables + """ + + ckpt_code = "\n".join(ckpt_def) + globals_copy = globals.copy() + _exec_with_source(ckpt_code, globals_copy) + func_list = [func for func in globals_copy.keys() if "checkpoint" in func or "pack" in func] + for func in func_list: + tmp_func = globals_copy[func] + setattr(self, func, tmp_func.__get__(self, self.__class__)) + del globals_copy[func] + + def recompile(self) -> PythonCode: + """ + Recompile this GraphModule from its ``graph`` attribute. This should be + called after editing the contained ``graph``, otherwise the generated + code of this ``GraphModule`` will be out of date. + """ + if isinstance(self._graph._codegen, _PyTreeCodeGen): + self._in_spec = self._graph._codegen.pytree_info.in_spec + self._out_spec = self._graph._codegen.pytree_info.out_spec + python_code = self._graph.python_code(root_module='self') + self._code = python_code.src + + # To split ckpt functions code and forward code + _code_list = self._code.split("\n") + _fwd_def = [item for item in _code_list if "def forward" in item][0] + _fwd_idx = _code_list.index(_fwd_def) + ckpt_def = _code_list[:_fwd_idx] + self._code = "\n".join(_code_list[_fwd_idx:]) + + self.bind(ckpt_def, python_code.globals) + + cls = type(self) + cls.forward = _forward_from_src(self._code, python_code.globals) + + # Determine whether this class explicitly defines a __call__ implementation + # to wrap. If it does, save it in order to have wrapped_call invoke it. + # If it does not, wrapped_call can use a dynamic call to super() instead. + # In most cases, super().__call__ should be torch.nn.Module.__call__. + # We do not want to hold a reference to Module.__call__ here; doing so will + # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. + cls_call = cls.__call__ if "__call__" in vars(cls) else None + + if '_wrapped_call' not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + + def call_wrapped(self, *args, **kwargs): + return self._wrapped_call(self, *args, **kwargs) + + cls.__call__ = call_wrapped + + # reset self._code to original src, otherwise to_folder will be wrong + self._code = python_code.src + return python_code + + def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"): + """Dumps out module to ``folder`` with ``module_name`` so that it can be + imported with ``from import `` + + Args: + + folder (Union[str, os.PathLike]): The folder to write the code out to + + module_name (str): Top-level name to use for the ``Module`` while + writing out the code + """ + folder = Path(folder) + Path(folder).mkdir(exist_ok=True) + torch.save(self.state_dict(), folder / 'state_dict.pt') + tab = " " * 4 + + # we add import colossalai here + model_str = f""" +import torch +from torch.nn import * +import colossalai + + +class {module_name}(torch.nn.Module): + def __init__(self): + super().__init__() +""" + + def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: + safe_reprs = [ + nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + ] + if type(module) in safe_reprs: + return f"{module.__repr__()}" + else: + return None + + blobified_modules = [] + for module_name, module in self.named_children(): + module_str = _gen_model_repr(module_name, module) + if module_str is None: + module_file = folder / f'{module_name}.pt' + torch.save(module, module_file) + blobified_modules.append(module_name) + module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') + module_str = f"torch.load(r'{module_file}') # {module_repr}" + model_str += f"{tab*2}self.{module_name} = {module_str}\n" + + for buffer_name, buffer in self._buffers.items(): + if buffer is None: + continue + model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" + + for param_name, param in self._parameters.items(): + if param is None: + continue + model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" + + model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" + model_str += f"{_addindent(self.code, 4)}\n" + + module_file = folder / 'module.py' + module_file.write_text(model_str) + + init_file = folder / '__init__.py' + init_file.write_text('from .module import *') + + if len(blobified_modules) > 0: + warnings.warn("Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}") + +else: + + class ColoGraphModule(GraphModule): + + def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'): + super().__init__(root, graph, class_name) diff --git a/colossalai/fx/passes/__init__.py b/colossalai/fx/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6f948cb2d3b37ce73e03044a944732a7d8370d6b --- /dev/null +++ b/colossalai/fx/passes/__init__.py @@ -0,0 +1,4 @@ +from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass +from .concrete_info_prop import ConcreteInfoProp +from .meta_info_prop import MetaInfoProp, metainfo_trace +from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..503397878be9895fc7990029f8cec98e4691a96c --- /dev/null +++ b/colossalai/fx/passes/adding_split_node_pass.py @@ -0,0 +1,142 @@ +import torch +from torch.fx import symbolic_trace +from torch.fx.node import Node + +from colossalai.fx.passes.split_module import split_module + + +def pipe_split(): + pass + + +def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int): + """ + In balanced_split_pass, we split module by the size of parameters(weights+bias). + """ + mod_graph = gm.graph + total_param_amount = 0 + for param in mod_graph.owning_module.parameters(): + total_param_amount += param.numel() + params_per_partition = total_param_amount // pp_size + accumulate_param_amount = 0 + for node in mod_graph.nodes: + if pp_size <= 1: + break + if node.op == "call_module": + target_module = node.graph.owning_module.get_submodule(node.target) + for param in target_module.parameters(): + accumulate_param_amount += param.numel() + if accumulate_param_amount >= params_per_partition: + accumulate_param_amount = 0 + pp_size -= 1 + # If the next node is output node, we will insert split annotation before + # node to make sure there is at least one node in last partition. + if node.next.op == 'output': + with mod_graph.inserting_before(node): + split_node = mod_graph.create_node('call_function', pipe_split) + else: + with mod_graph.inserting_after(node): + split_node = mod_graph.create_node('call_function', pipe_split) + if pp_size > 1: + node_counter = 0 + for node in mod_graph.nodes: + if pp_size <= 1: + break + if node.op == 'placeholder': + continue + elif node_counter == 0: + node_counter += 1 + else: + pp_size -= 1 + node_counter = 0 + with mod_graph.inserting_before(node): + split_node = mod_graph.create_node('call_function', pipe_split) + + gm.recompile() + return gm + + +def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int): + """ + In balanced_split_pass_v12, we split module by the size of nodes(weights+bias+outputs). + """ + mod_graph = gm.graph + # To use balanced_split_pass_v2, we need run meta_info_prop interpreter first. + # If nodes don't have meta info, this pass will fall back to normal balanced split pass. + check_node = list(mod_graph.nodes)[0] + if 'tensor_meta' not in check_node.meta: + return balanced_split_pass(gm, pp_size) + + total_element_size = 0 + for node in mod_graph.nodes: + total_element_size += node.node_size + + partition_size = total_element_size // pp_size + accumulate_node_size = 0 + for node in mod_graph.nodes: + if pp_size <= 1: + break + if 'pipe_split' in node.name: + continue + accumulate_node_size += node.node_size + if accumulate_node_size >= partition_size: + accumulate_node_size = 0 + pp_size -= 1 + with mod_graph.inserting_after(node): + split_node = mod_graph.create_node('call_function', pipe_split) + gm.recompile() + return gm + + +def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int): + mod_graph = gm.graph + valid_children_size = 0 + valid_children = [] + for module in mod_graph.owning_module.children(): + valid_children_size += 1 + valid_children.append(module) + + if valid_children_size < pp_size: + # If valid children is not enough to shard, we will use balanced policy instead of uniform policy. + return balanced_split_pass(gm, pp_size) + layers_per_partition = valid_children_size // pp_size + accumulate_layer_amount = 0 + for node in mod_graph.nodes: + if pp_size <= 1: + break + if node.op == "call_module": + target_module = node.graph.owning_module.get_submodule(node.target) + if target_module in valid_children: + accumulate_layer_amount += 1 + if accumulate_layer_amount == layers_per_partition: + accumulate_layer_amount = 0 + pp_size -= 1 + with mod_graph.inserting_after(node): + split_node = mod_graph.create_node('call_function', pipe_split) + gm.recompile() + return gm + + +def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output=False): + # TODO(lyl): use partition IR to assign partition ID to each node. + # Currently: analyzing graph -> annotate graph by inserting split node -> use split module pass to split graph + # In future: graph to partitions -> analyzing partition IR -> recombining partitions to get best performance -> assign partition ID to each node + part_idx = 0 + + def split_callback(n: torch.fx.Node): + nonlocal part_idx + if (n.op, n.target) == ('call_function', pipe_split): + part_idx += 1 + return part_idx + + split_mod = split_module(annotated_gm, None, split_callback, merge_output) + split_submodules = [] + for name, submodule in split_mod.named_modules(): + if isinstance(submodule, torch.fx.GraphModule): + for node in submodule.graph.nodes: + if (node.op, node.target) == ('call_function', pipe_split): + submodule.graph.erase_node(node) + submodule.recompile() + split_submodules.append(submodule) + + return split_mod, split_submodules diff --git a/colossalai/fx/passes/algorithms/__init__.py b/colossalai/fx/passes/algorithms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0ccf28f28e9395fa73db478924ddc0c3fa7ded3a --- /dev/null +++ b/colossalai/fx/passes/algorithms/__init__.py @@ -0,0 +1,4 @@ +from .ckpt_solver_chen import chen_greedy +from .linearize import linearize +from .ckpt_solver_rotor import solver_rotor +from .ckpt_solver_pofo import solver_pofo diff --git a/colossalai/fx/passes/algorithms/build_c_ext.py b/colossalai/fx/passes/algorithms/build_c_ext.py new file mode 100644 index 0000000000000000000000000000000000000000..cb360cb2034056be414b1fc8253712421ca7c311 --- /dev/null +++ b/colossalai/fx/passes/algorithms/build_c_ext.py @@ -0,0 +1,15 @@ +from setuptools import setup, Extension +import os + +this_dir = os.path.dirname(os.path.abspath(__file__)) +ext_modules = [Extension( + 'dynamic_programs_C_version', + sources=[os.path.join(this_dir, 'dynamic_programs.c')], +)] + +setup( + name='rotor c extension', + version='0.1', + description='rotor c extension for faster dp computing', + ext_modules=ext_modules, +) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py new file mode 100644 index 0000000000000000000000000000000000000000..52000ebe536475f0d5452727589a24f36286f7f9 --- /dev/null +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -0,0 +1,98 @@ +import math +from typing import List, Set, Tuple + +import torch +from torch.fx import GraphModule, Node + +from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp + +__all__ = ['chen_greedy'] +CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr'] + + +def _all_potential_ckpt_nodes(gm: GraphModule) -> List: + """ + In most existing frameworks of activation checkpoint, the forward graph is assumed to be linearized. + """ + + def is_sink(): + """ + If we can free all memories when executing a certain node, it is a sink. + """ + return not sum((v for k, v in deps.items())) + + deps = {} + ckpt_nodes = [] + for n in gm.graph.nodes: + for n_par in n._input_nodes: + deps[n_par] -= 1 # free memory and dependencies + + # We can only put act_ckpt on these nodes + if n.op in CKPT_OP and is_sink(): + ckpt_nodes.append(n) + deps[n] = len(n.users) # add dependencies for future executions + return ckpt_nodes + + +def chen_greedy(gm: GraphModule) -> GraphModule: + """ + This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. + Note that this algorithm targets at memory optimization only, using techniques in appendix A. + + Usage: + model = resnet18() + input_sample = torch.rand(4, 3, 224, 224) + gm = symbolic_trace(model) + MetaInfoProp(gm).run(input_sample) + gm = chen_greedy(gm) + + Args: + gm (GraphModule): The module to add checkpoints + """ + + def grid_search(num_grids: int = 6) -> Set: + """ + Search ckpt strategy with b = 0, then run the allocation algorithm again with b = โˆšxy. + Grid search over [โˆš2/2 b, โˆš2 b] for ckpt_opt over num_grids as in appendix A. + """ + _, b_approx = run_chen_greedy(0) + b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2)) + b_opt = math.inf + for b in range(b_min, b_max, (b_max - b_min) // num_grids): + ckpt_intv, b_approx = run_chen_greedy(b) + if b_approx < b_opt: + b_opt = b_approx + ckpt_opt = ckpt_intv + return ckpt_opt + + def run_chen_greedy(b: int = 0) -> Tuple[Set, int]: + """ + This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. + """ + ckpt_nodes = _all_potential_ckpt_nodes(gm) + ckpt_intv = [] + temp = 0 + x = 0 + y = 0 + prev_idx = 2 + for (idx, n) in enumerate(gm.graph.nodes): + n: Node + temp += calculate_fwd_in(n) + calculate_fwd_tmp(n) + y = max(y, temp) + if temp > b and n in ckpt_nodes: + x += calculate_fwd_in(n) + temp = 0 + ckpt_intv.append((prev_idx, idx + 1)) + prev_idx = idx + 1 + return ckpt_intv, math.floor(math.sqrt(x * y)) + + gm.graph.lint() # make sure nodes are in topological order + ckpt = grid_search(num_grids=6) + node_list = list(gm.graph.nodes) + for i, seg in enumerate(ckpt): + for idx in range(*seg): + n = node_list[idx] + if n.op in CKPT_OP: + setattr(n, 'activation_checkpoint', i) + gm.recompile() + return gm diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py b/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py new file mode 100644 index 0000000000000000000000000000000000000000..69e4e9f2cce8b684a8828cc93a812ba846cc95fa --- /dev/null +++ b/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py @@ -0,0 +1,537 @@ +import copy +import math +from typing import List, Tuple + +import torch +from colossalai.fx import is_compatible_with_meta +from colossalai.fx.codegen.activation_checkpoint_codegen import \ + _find_nested_ckpt_regions +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.algorithms.ckpt_solver_rotor import (_compute_table, _construct_chain, _rec) +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.profiler import parameter_size +from torch.fx import GraphModule, Node + +from .linearize import linearize +from .operation import (Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Offload, Prefetch, + Sequence) + +INF = float("inf") + + +def _normalize_flops(chain: Chain, flops) -> Chain: + """ + Normalize flops + """ + for i in range(chain.length): + chain.fweight[i] /= flops + chain.bweight[i] /= flops + + return chain + + +class PofoTable: + """PofoTable + The PofoTable contains the necessary components to store intermediate results + of dynamic programming and the operations alone the way. + """ + + def __init__(self, chain_length: int, mem_slots: int): + """Init pofo table + The pofo table contains two tables, opt and what, indicating values and + operations. + + Args: + chain_length (int): chain length + mem_slots (int): number of memory slots + """ + + self.length = chain_length + self.mem_slots = mem_slots + + # initializing tables + # the first bool indicates whether the input has bar + # opt table is for value, opt[True/False][i][A][(df, db)] = OCx(i, A, df, db) + # what table is for decision, what[True/False][i][A][(df, db)] = (is_enable, is_offload, index) + # where is_enable indicates whether we enable the gradient, is_offload indicates whether we + # offload the input, index indicates the end of F_\empty sequence if is_enable = False + self.opt = { + False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)], + True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)] + } + self.what = { + False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)], + True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)] + } + + def _get_value(self, state, table, default): + i, act_size, df, db, input_has_bar = state + if act_size + df > self.mem_slots or act_size + db > self.mem_slots: + return default + + try: + return table[input_has_bar][i][act_size][(df, db)] + except KeyError: + print(f"state not found {state}") + + def get_opt(self, state): + return self._get_value(state, self.opt, INF) + + def get_what(self, state): + return self._get_value(state, self.what, INF) + + def set_value(self, state, opt, what): + i, act_size, df, db, input_has_bar = state + self.opt[input_has_bar][i][act_size][(df, db)] = opt + self.what[input_has_bar][i][act_size][(df, db)] = what + + +class PofoSolver: + """PofoSolver that executes algorithm mentioned in https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html + The new pofo solver is based on paper Efficient Combination of Rematerialization and Offloading for Training DNNs + and it's code given in the supplemental. Currently we doesn't use the whole set up in the original paper and reuse + rotor solver for the backward sequence as suggested in supplemental. The solver now is able to find strategy with offload. + """ + + def __init__(self, chain: Chain, max_memory: int, bandwidth, mem_slots: int) -> None: + self.chain = chain + self.length = chain.length + self.max_memory = max_memory + self.mem_slots = mem_slots + self.mem_unit = max_memory / mem_slots + self.bandwidth = bandwidth + + self.disc_chain = copy.deepcopy(self.chain) + self.disc_chain._discretize(self.mem_unit) + + self.rotor_table = _compute_table(self.disc_chain, mem_slots) + self._compute_pofo_table() + + def _discretize(self, *values) -> Tuple: + return tuple(math.ceil(value / self.mem_unit) for value in values) + + def _undiscretize(self, *discrete_values) -> Tuple: + if len(discrete_values) == 1: + return discrete_values[0] * self.mem_unit + else: + return tuple(d * self.mem_unit for d in discrete_values) + + def _mmax_all(self, idx: int): + """ + Calculate the maximum memory usage of Fi_all + """ + + return self.chain.cbweight[idx + 1] + self.chain.fwd_mem_tmp[idx] + + def _mmax_b(self, idx: int): + """ + Calculate the maximum memory usage of Bi + """ + + return self.chain.cbweight[idx + + 1] + self.chain.cweight[idx + + 1] + self.chain.cweight[idx] + self.chain.bwd_mem_tmp[idx] + + def _mmax_ng(self, i: int, j: int): + """ + Calculate the maximum memory usage of CF_i, F_i+1\empty, ... F_j\empty + """ + + res = self.chain.cweight[j + 1] + self.chain.fwd_mem_tmp[j] + if j > i: + res += self.chain.cweight[j] + return res + + def _rotor_estimated_bwd(self, i, j, m, delta): + compute = self.rotor_table[0][math.floor((m - self.chain.cweight[i]) / self.mem_unit)][i][j] + comm = delta / self.bandwidth + return (max(compute, comm) + compute + comm) / 2 + + def _rotor_estimated_bwd_sequence(self, i, j, m, delta): + return _rec(self.disc_chain, i, j, math.floor((m - self.chain.cweight[i]) / self.mem_unit), self.rotor_table) + + def _common_values_enable(self, state: Tuple): + + idx, act_size, df, db, input_has_bar = state + input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx] + mf = act_size + df + input_size + mb = act_size + db + input_size + mem_avail = self.max_memory - act_size - input_size + f_usage = self._mmax_all(idx) + b_usage = self._mmax_b(idx) + + # infeasible + if f_usage > mem_avail or b_usage > mem_avail: + return None + + # calculate idle time + eps_f_beta = max(0, f_usage - self.max_memory + mf) + eps_b_beta = max(0, b_usage - self.max_memory + mb) + idle_time = (eps_f_beta + eps_b_beta) / self.bandwidth + + # calculate offload and prefetch data + offload_data = self.chain.fweight[idx] * self.bandwidth + eps_f_beta + prefetch_data = self.chain.bweight[idx] * self.bandwidth + eps_b_beta + + # total_time + total_time = self.chain.fweight[idx] + self.chain.bweight[idx] + idle_time + + return (offload_data, prefetch_data, total_time, idle_time) + + def _common_values_nograd(self, state: Tuple, j: int, iterative: bool = False): + + i, act_size, df, db, input_has_bar = state + + # compute new epsilon_tmp and sum_fwds + if iterative: + self.epsilon_tmp = max(self.epsilon_tmp, self._mmax_ng(i, j) - self.bandwidth * self.sum_fwds) + self.sum_fwds += self.chain.fweight[j] + else: + self.epsilon_tmp = max( + self._mmax_ng(i, k) - self.bandwidth * sum(self.chain.fweight[i:k]) for k in range(i, j + 1)) + self.sum_fwds = sum(self.chain.fweight[i:j + 1]) + + input_size = self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i] + mf = act_size + df + input_size + mem_avail = self.max_memory - act_size - input_size + + # if infeasible + if max(self._mmax_ng(i, k) for k in range(i, self.length)) > mem_avail: + return None + + eps_f_beta = max(0, self.epsilon_tmp - self.max_memory + mf) + offload_data = self.sum_fwds * self.bandwidth + eps_f_beta + + # TODO: Implement the precise backward recompute sequence mentioned in the paper + # currently we will use an approximate way to get the backward time + time_backward = self._rotor_estimated_bwd(i, j, mem_avail, db) + + prefetch_data = time_backward * self.bandwidth + idle_time = eps_f_beta / self.bandwidth + total_time = self.sum_fwds + idle_time + time_backward + + return (offload_data, prefetch_data, total_time, idle_time) + + def _new_values(self, state: Tuple, do_offload: bool, common_values: Tuple) -> Tuple: + """Generate new values for next state + + Args: + state (Tuple): undiscretized states + do_offload (bool): bool type indicates whether we need to do offload + common_values (Tuple): common values (offload_data, prefetch_data, total_time, idle_time) + + Returns: + Tuple: (new_act_size, new_df, new_db) + """ + idx, act_size, df, db, input_has_bar = state + offload_data, prefetch_data, *_ = common_values + input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx] + if do_offload: + new_act_size = act_size + new_df = max(0, df + input_size - offload_data) + new_db = max(0, db - prefetch_data) + input_size + else: + new_act_size = act_size + input_size + new_df = max(0, df - offload_data) + new_db = max(0, db - prefetch_data) + + return (new_act_size, new_df, new_db) + + def _compute_pofo_table(self): + self.table = PofoTable(self.length, self.mem_slots) + + # initializing the loss + for act_size in range(self.mem_slots + 1): + for df in range(self.mem_slots - act_size + 1): + for db in range(self.mem_slots - act_size + 1): + # undiscretize for idle time calculation + origin_values = self._undiscretize(act_size, df, db) + + for input_has_bar in (False, True): + disc_state = (self.length, act_size, df, db, input_has_bar) + state = (self.length, *origin_values, input_has_bar) + common_values = self._common_values_enable(state) + + # if no feasible choice + if common_values is None: + self.table.set_value(disc_state, INF, None) + continue + + # if there is feasible choice + new_act_size, new_df, new_db = self._new_values(state, False, common_values) + eps_g = (new_df + new_db) / self.bandwidth + total_time = common_values[2] + eps_g + self.table.set_value(disc_state, total_time, (True, False)) + + # main loop + for i in reversed(range(self.length)): + for act_size in range(self.mem_slots + 1): + for df in range(self.mem_slots - act_size + 1): + for db in range(self.mem_slots - act_size + 1): + # undiscretize for idle time calculation + origin_values = self._undiscretize(act_size, df, db) + + for input_has_bar in (False, True): + best_result = INF + best_choice = None + disc_state = (i, act_size, df, db, input_has_bar) + state = (i, *origin_values, input_has_bar) + + # case 1: start with F_all + vals_enable = self._common_values_enable(state) + if vals_enable is not None: + for do_offload in (True, False): + new_state = self._new_values(state, do_offload, vals_enable) + new_state = (i + 1, *self._discretize(*new_state), True) + total_time = vals_enable[2] + results_all = self.table.get_opt(new_state) + total_time + if results_all < best_result: + best_result = results_all + best_choice = (True, do_offload) + + # case 2: start with F_ck + self.sum_fwds = 0 + self.epsilon_tmp = 0 + for j in range(i, self.length): + vals_nograd = self._common_values_nograd(state, j, True) + + # if infeasible + if vals_nograd is None: + continue + + for do_offload in (True, False): + new_state = self._new_values(state, do_offload, vals_nograd) + new_state = (j + 1, *self._discretize(*new_state), False) + total_time = vals_nograd[2] + result_nograd = total_time + self.table.get_opt(new_state) + if result_nograd < best_result: + best_result = result_nograd + best_choice = (False, do_offload, j) + + self.table.set_value(disc_state, best_result, best_choice) + + def pofo_rec(self, disc_state): + i, act_size, df, db, input_has_bar = disc_state + result = Sequence(Function("pofo", *disc_state)) + what = self.table.get_what(disc_state) + state = self._undiscretize(act_size, df, db) + state = (i, *state, input_has_bar) + i, act_size, df, db, input_has_bar = state + + if what is None: + return None + + # if loss + if i == self.length: + result.insert(Loss()) + return result + + if what[0]: + do_offload = what[1] + values = self._common_values_enable(state) + new_state = self._discretize(*self._new_values(state, do_offload, values)) + new_state = (i + 1, *new_state, True) + if do_offload: + result.insert(Offload(i, input_has_bar)) + result.insert(ForwardEnable(i)) + result.insert_sequence(self.pofo_rec(new_state)) + if do_offload: + result.insert(Prefetch(i, input_has_bar)) + result.insert(Backward(i)) + + else: + _, do_offload, j = what + values = self._common_values_nograd(state, j) + new_state = self._discretize(*self._new_values(state, do_offload, values)) + new_state = (j + 1, *new_state, False) + if do_offload: + result.insert(Offload(i, input_has_bar)) + result.insert(ForwardCheck(i)) + for k in range(i + 1, j + 1): + result.insert(ForwardNograd(k)) + result.insert_sequence(self.pofo_rec(new_state)) + if do_offload: + result.insert(Prefetch(i, input_has_bar)) + m = self.max_memory - act_size - (self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i]) + + #TODO: Implement the precise backward recompute sequence mentioned in the paper + result.insert_sequence(self._rotor_estimated_bwd_sequence(i, j, m, db)) + + return result + + +def _annotate_from_pofo_sequence(sequence: Sequence, node_list: List[List[Node]]): + op_list = sequence.list_operations() + loss_op = next(op for op in op_list if isinstance(op, Loss)) + fwd_list = op_list[:op_list.index(loss_op)] + bwd_list = op_list[op_list.index(loss_op) + 1:] + ckpt_idx = 0 + in_ckpt = False + ckpt_region = [] + + # forward annotation + for op in fwd_list: + if in_ckpt: + if isinstance(op, ForwardNograd): + ckpt_region.append(op.index) + + elif isinstance(op, ForwardEnable): + in_ckpt = False + for node_idx in ckpt_region: + for n in node_list[node_idx]: + setattr(n, "activation_checkpoint", [ckpt_idx]) + + ckpt_idx += 1 + ckpt_region = [] + + elif isinstance(op, ForwardCheck): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + setattr(n, "activation_checkpoint", [ckpt_idx]) + + ckpt_idx += 1 + ckpt_region = [op.index] + + else: + if isinstance(op, ForwardCheck): + in_ckpt = True + ckpt_region.append(op.index) + + # annotate the backward if there is any nested activation checkpoint + in_recompute = False + for op in bwd_list: + if in_recompute: + if isinstance(op, ForwardNograd): + ckpt_region.append(op.index) + + elif isinstance(op, ForwardEnable): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.activation_checkpoint.append(ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [] + + elif isinstance(op, ForwardCheck): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.activation_checkpoint.append(ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [op.index] + + elif isinstance(op, Backward): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.activation_checkpoint.append(ckpt_idx) + + in_recompute = False + + else: + if not isinstance(op, Backward): + in_recompute = True + ckpt_idx = 0 + ckpt_region = [] + if isinstance(op, ForwardCheck): + ckpt_region.append(op.index) + + # postprocess, make sure every activation checkpoint label in the + # same activation checkpoint region (level = 0) has the same length + op_list = [] + for node in node_list: + op_list += node + ckpt_regions = _find_nested_ckpt_regions(op_list) + for (start_idx, end_idx) in ckpt_regions: + nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1)) + for idx in range(start_idx, end_idx + 1): + op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint)) + + # annotate the offload + offload_idx = 0 + for idx, op in enumerate(fwd_list): + if isinstance(op, Offload): + # corner case: offload input + if op.index == 0: + if isinstance(fwd_list[idx + 1], ForwardCheck): + for n in node_list[op.index]: + setattr(n, "activation_offload", True) + else: + for n in node_list[op.index]: + setattr(n, "activation_offload", (offload_idx, True, False)) + offload_idx += 1 + + else: + if op.has_bar: + # annotate previous node + if hasattr(node_list[op.index - 1][0], "activation_offload"): + for n in node_list[op.index - 1]: + n.activation_offload[-1] = True + else: + for n in node_list[op.index - 1]: + setattr(n, "activation_offload", [offload_idx, False, True]) + + offload_idx += 1 + + # annotate this node + if isinstance(fwd_list[idx + 1], ForwardCheck): + for n in node_list[op.index]: + setattr(n, "activation_offload", True) + else: + for n in node_list[op.index]: + setattr(n, "activation_offload", [offload_idx, True, False]) + + offload_idx += 1 + + +def solver_pofo(gm: ColoGraphModule, + data, + bandwidth, + flops, + mem_limit: int, + mem_slots: int = 50, + cnode: List[str] = None, + eps: float = 0.0) -> ColoGraphModule: + """Solver that combine offload and activation checkpoint + Reference: https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html + + Args: + gm (ColoGraphModule): ColoGraphModule derived from tracer + data: input of the model + bandwidth: offload bandwidth, unit Byte/s + flops: FLOPS of device, unit FLOPs/s + mem_limit (int): memory limit, unit Byte + mem_slots (int, optional): number of memory slots. Defaults to 500. + cnode (List[str], optional): common node for linearize. Defaults to None. + eps (float, optional): epsilon for memory decay. Defaults to 0.02. + + Returns: + ColoGraphModule: annotated graph module + """ + + node_list = linearize(gm, cnode) + mem_limit -= parameter_size(gm) + + # prepare data + if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + data = MetaTensor(data, fake_device=next(gm.parameters()).device) + MetaInfoProp(gm).run(data) + chain: Chain = _construct_chain(node_list, data) + chain = _normalize_flops(chain, flops) + # currently we view loss as an op without expense + chain.cbweight.append(0) + chain.cweight.append(0) + chain.fwd_mem_tmp.append(0) + chain.bwd_mem_tmp.append(0) + chain.fweight.append(0) + chain.bweight.append(0) + + solver = PofoSolver(chain, mem_limit, bandwidth, mem_slots) + first_state = (0, 0, 0, 0, False) + sequence = solver.pofo_rec(first_state) + if sequence == None: + raise ValueError(f"Cannot solve sequence with {mem_limit} Bytes memory") + + _annotate_from_pofo_sequence(sequence, node_list) + setattr(gm, "__sequence__", sequence) + return gm diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py new file mode 100644 index 0000000000000000000000000000000000000000..5b8d0da9ffe6665a5135ebb483cc10066ea78acc --- /dev/null +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -0,0 +1,436 @@ +import math +import sys +from typing import List, Tuple + +from torch.fx import Node + +from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size +from colossalai.logging import get_dist_logger + +from .linearize import linearize +from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence + +# global vairable to indicate whether the solver is failed +SOLVER_FAILED = False + + +# this is the python compute table code from rotor +# https://gitlab.inria.fr/hiepacs/rotor +# paper link: https://hal.inria.fr/hal-02352969 +def _compute_table(chain: Chain, mmax) -> Tuple: + """Returns the optimal table: a tuple containing: + Opt[m][lmin][lmax] with lmin = 0...chain.length + and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax + what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint + (False, j) if the optimal choice is a leaf checkpoint of length j + The computation uses dynamic programming""" + + fw = chain.fweight + [0] ## forward time + bw = chain.bweight ## backward time, not used + cw = chain.cweight + [0] ## size of x (and of y) + cbw = chain.cbweight + [0] ## size of xbar + fwd_mem_tmp = chain.fwd_mem_tmp + [0] + bwd_mem_tmp = chain.bwd_mem_tmp + [0] + + # Build table + opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)] + what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)] + # Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation + + # Initialize borders of the tables for lmax-lmin = 0 + for m in range(mmax + 1): + for i in range(chain.length + 1): + #lmax-lmin = 0 + limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i]) + if m >= limit: ## Equation (1) + opt[m][i][i] = fw[i] + bw[i] + else: + opt[m][i][i] = float("inf") + + # Compute everything + for m in range(mmax + 1): + for d in range(1, chain.length + 1): + for i in range(chain.length + 1 - d): + # for idx in range(i+1, chain.length + 1): + idx = i + d + mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i] + if idx > i + 1: + mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j] for j in range(i + 1, idx))) + if m < mmin: + opt[m][i][idx] = float("inf") + else: + leaf_checkpoints = [(j, sum(fw[i:j]) + opt[m - cw[j]][j][idx] + opt[m][i][j - 1]) + for j in range(i + 1, idx + 1) + if m >= cw[j]] + if leaf_checkpoints: + best_leaf = min(leaf_checkpoints, key=lambda t: t[1]) + else: + best_leaf = None + if m >= cbw[i + 1]: + chain_checkpoint = opt[m][i][i] + opt[m - cbw[i + 1]][i + 1][idx] + else: + chain_checkpoint = float("inf") + if best_leaf and best_leaf[1] <= chain_checkpoint: + opt[m][i][idx] = best_leaf[1] + what[m][i][idx] = (False, best_leaf[0]) + else: + opt[m][i][idx] = chain_checkpoint + what[m][i][idx] = (True,) + return (opt, what) + + +def _rec(chain: Chain, lmin, lmax, cmem, opt_table): + """ chain : the class describing the AC graph + lmin : index of the first forward to execute + lmax : upper bound index of the last forward to execute (not included) + cmem : number of available memory slots + Return the optimal sequence of makespan Opt_hete[cmem][lmin][lmax-lmin]""" + if cmem <= 0: + raise ValueError("Can not process a chain with negative memory {cmem}".format(cmem=cmem)) + opt, what = opt_table + sequence = Sequence(Function("Persistent", lmax - lmin, cmem)) + if opt[cmem][lmin][lmax] == float("inf"): + # using logger to annonce that the solver is failed + logger = get_dist_logger() + logger.info("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin, + lmax=lmax, + cmem=cmem)) + + # set global indicater SOLVER_FAILED to True + global SOLVER_FAILED + SOLVER_FAILED = True + return sequence + + if lmin == lmax: + if lmin == chain.length: + sequence.insert(Loss()) + else: + sequence.insert(ForwardEnable(lmin)) + sequence.insert(Backward(lmin)) + return sequence + + if what[cmem][lmin][lmax][0]: + sequence.insert(ForwardEnable(lmin)) + sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweight[lmin + 1], opt_table)) + sequence.insert(Backward(lmin)) + else: + j = what[cmem][lmin][lmax][1] + sequence.insert(ForwardCheck(lmin)) + for k in range(lmin + 1, j): + sequence.insert(ForwardNograd(k)) + sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweight[j], opt_table)) + sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table)) + return sequence + + +def _fwd_xbar(node: List[Node]) -> int: + """Get the forward xbar of a node + + Args: + node (List[Node]): List of torch.fx Node, + indicates a node in linearized graph + + Returns: + int: xbar size, unit Byte + """ + + xbar = 0 + for n in node: + xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) + return xbar + + +def _fwd_time(node: List[Node]) -> int: + """Get the foward time of a node + + Args: + node (List[Node]): List of torch.fx Node, + indicates a node in linearized graph + + Returns: + int: foward time, extimated by flops count + """ + + fwd_time = 0 + for n in node: + # minimum flop count is needed + fwd_time += max(n.meta['fwd_flop'], 1) + return fwd_time + + +def _bwd_time(node: List[Node]) -> int: + """Get the backward time of a node + + Args: + node (List[Node]): List of torch.fx Node, + indicates a node in linearized graph + + Returns: + int: backward time, extimated by flops count + """ + + bwd_time = 0 + for n in node: + # minimum flop count is needed + bwd_time += max(n.meta['bwd_flop'], 1) + return bwd_time + + +def _get_fwd_mem_tmp(node: List[Node]) -> int: + """Get the forward temp memory of a node + This could be done by subtracting the saved activation from all output of a node + + Args: + node (List[Node]): List of torch.fx Node, + indicates a node in linearized graph + + Returns: + int: forward temp memory, unit Byte + """ + n = node[-1] + return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n) + + +def _get_bwd_mem_tmp(node: List[Node]) -> int: + """Get the backward temp memory of a node + + Args: + node (List[Node]): List of torch.fx Node, + indicates a node in linearized graph + + Returns: + int: backward temp memory, unit Byte + """ + + def _get_deps_size(): + deps_size = 0 + for k, v in deps.items(): + k: Node + if v > 0: + deps_size += k.meta['bwd_mem_out'] + if v == float('-inf'): + deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k) + + return deps_size + + bwd_mem_tmp = 0 + deps = {} + + for n in reversed(node): + deps[n] = len(n.all_input_nodes) + bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp']) + + for child in n.users: + if child in deps: + deps[child] -= 1 + if deps[child] <= 0: + deps[child] = float('-inf') # free + + return bwd_mem_tmp + + +def _construct_chain(node_list: List[List[Node]], input) -> Chain: + + fwd_time = [] + bwd_time = [] + xbar_sizes = [activation_size(input)] + x_sizes = [activation_size(input)] + tmp_fwd = [] + tmp_bwd = [] + + for idx, node in enumerate(node_list): + fwd_time.append(_fwd_time(node)) + bwd_time.append(_bwd_time(node)) + x_sizes.append(calculate_fwd_out(node[-1])) + xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node))) + tmp_fwd.append(_get_fwd_mem_tmp(node)) + tmp_bwd.append(_get_bwd_mem_tmp(node)) + + bwd_time.append(0) + + # currently we view loss backward temp as zero + tmp_bwd.append(0) + + return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd) + + +def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): + op_list = sequence.list_operations() + loss_op = next(op for op in op_list if isinstance(op, Loss)) + fwd_list = op_list[:op_list.index(loss_op)] + bwd_list = op_list[op_list.index(loss_op) + 1:] + ckpt_idx = 0 + in_ckpt = False + ckpt_region = [] + + # forward annotation + for idx, op in enumerate(fwd_list, 0): + if in_ckpt: + if isinstance(op, ForwardNograd): + ckpt_region.append(idx) + + elif isinstance(op, ForwardEnable): + in_ckpt = False + for node_idx in ckpt_region: + for n in node_list[node_idx]: + setattr(n, "activation_checkpoint", [ckpt_idx]) + + ckpt_idx += 1 + ckpt_region = [] + + elif isinstance(op, ForwardCheck): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + setattr(n, "activation_checkpoint", [ckpt_idx]) + + ckpt_idx += 1 + ckpt_region = [idx] + + else: + if isinstance(op, ForwardCheck): + in_ckpt = True + ckpt_region.append(idx) + + # annotate the backward if there is any nested activation checkpoint + in_recompute = False + for op in bwd_list: + if in_recompute: + if isinstance(op, ForwardNograd): + ckpt_region.append(op.index) + + elif isinstance(op, ForwardEnable): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.activation_checkpoint.append(ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [] + + elif isinstance(op, ForwardCheck): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.activation_checkpoint.append(ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [op.index] + + elif isinstance(op, Backward): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.activation_checkpoint.append(ckpt_idx) + + in_recompute = False + + else: + if not isinstance(op, Backward): + in_recompute = True + ckpt_idx = 0 + ckpt_region = [] + if isinstance(op, ForwardCheck): + ckpt_region.append(op.index) + + # postprocess, make sure every activation checkpoint label in the + # same activation checkpoint region (level = 0) has the same length + op_list = [] + for node in node_list: + op_list += node + ckpt_regions = _find_nested_ckpt_regions(op_list) + for (start_idx, end_idx) in ckpt_regions: + nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1)) + for idx in range(start_idx, end_idx + 1): + op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint)) + + +def solver_rotor(gm: ColoGraphModule, + data, + mem_limit: int, + mem_slots: int = 500, + cnode: List[str] = None, + eps: float = 0.0, + force_python: bool = False) -> ColoGraphModule: + """solver that automatically find activation checkpoint in rotor's manner + + Args: + gm (ColoGraphModule): ColoGraphModule generated by tracing model and MetaInfoProp. + data (torch.Tensor): input data. + mem_limit (int): memory budget in Byte. + mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500. + cnode (List[Node], optional): common node list for linearize. Defaults to None. + eps (float): epsilon for memory decay. Defaults to 0.0 + force_python (bool): force to use python version of dynamic programs + + Returns: + ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute + """ + + # try to import C version solver if force_python is not set + logger = get_dist_logger() + if not force_python: + try: + from .dynamic_programs_C_version import persistent_compute_table + CVERSION = True + + # build module if module not found + except ModuleNotFoundError: + import os + import subprocess + logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0]) + this_dir = os.path.dirname(os.path.abspath(__file__)) + result = subprocess.Popen( + [ + f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext", + f"--build-lib={this_dir}" + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if result.wait() == 0: + logger.info("dynamic_programs_C_version has been built!", ranks=[0]) + from .dynamic_programs_C_version import persistent_compute_table + CVERSION = True + else: + logger.info("dynamic_programs_C_version built failed! Using python version!", ranks=[0]) + CVERSION = False + else: + CVERSION = False + + # check if metainfoprop is done + if any(len(node.meta) == 0 for node in gm.graph.nodes): + raise RuntimeError( + "Nodes meta information hasn't been prepared! Please run MetaInfoProp before calling solver!") + + # linearize the graph + node_list = linearize(gm, cnode) + + # construct chain + mem_unit = mem_limit * (1.0 - eps) // mem_slots + chain: Chain = _construct_chain(node_list, data) + chain._discretize(mem_unit) + + # use C version if possible + if CVERSION and not force_python: + logger.info("Using C version rotor solver!", ranks=[0]) + opt_table = persistent_compute_table(chain, mem_slots) + else: + opt_table = _compute_table(chain, mem_slots) + logger.info("Using python version rotor solver!", ranks=[0]) + + # found sequence + sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table) + + # if solver failed, we don't need to annotate the graph + if not SOLVER_FAILED: + _annotate_from_sequence(sequence, node_list) + + # set __sequence__ attribute to GraphModule + if SOLVER_FAILED: + setattr(gm, "__sequence__", None) + else: + setattr(gm, "__sequence__", sequence) + + # set __opttable__ attribute to GraphModule + setattr(gm, "__opttable__", opt_table[0]) + gm.recompile() + return gm diff --git a/colossalai/fx/passes/algorithms/dynamic_programs.c b/colossalai/fx/passes/algorithms/dynamic_programs.c new file mode 100644 index 0000000000000000000000000000000000000000..3efad58400fa40e377a9c1e740478bb4d1e9f6a3 --- /dev/null +++ b/colossalai/fx/passes/algorithms/dynamic_programs.c @@ -0,0 +1,516 @@ +#define PY_SSIZE_T_CLEAN +#include + +long* PySequenceToLongArray(PyObject* pylist) { + if (!(pylist && PySequence_Check(pylist))) return NULL; + Py_ssize_t len = PySequence_Size(pylist); + long* result = (long*)calloc(len + 1, sizeof(long)); + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject* item = PySequence_GetItem(pylist, i); + result[i] = PyLong_AsLong(item); + Py_DECREF(item); + } + result[len] = 0; + return result; +} + +double* PySequenceToDoubleArray(PyObject* pylist) { + if (!(pylist && PySequence_Check(pylist))) return NULL; + Py_ssize_t len = PySequence_Size(pylist); + double* result = (double*)calloc(len + 1, sizeof(double)); + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject* item = PySequence_GetItem(pylist, i); + result[i] = PyFloat_AsDouble(item); + Py_DECREF(item); + } + result[len] = 0; + return result; +} + +long* getLongArray(PyObject* container, const char* attributeName) { + PyObject* sequence = PyObject_GetAttrString(container, attributeName); + long* result = PySequenceToLongArray(sequence); + Py_DECREF(sequence); + return result; +} + +double* getDoubleArray(PyObject* container, const char* attributeName) { + PyObject* sequence = PyObject_GetAttrString(container, attributeName); + double* result = PySequenceToDoubleArray(sequence); + Py_DECREF(sequence); + return result; +} + +static PyObject* persistent_compute_table(PyObject* self, PyObject* args) { + PyObject* chain_param; + int mmax; + + if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL; + + double* fw = getDoubleArray(chain_param, "fweight"); + if (!fw) return NULL; + + double* bw = getDoubleArray(chain_param, "bweight"); + if (!bw) return NULL; + + long* cw = getLongArray(chain_param, "cweight"); + if (!cw) return NULL; + + long* cbw = getLongArray(chain_param, "cbweight"); + if (!cbw) return NULL; + + long* fwd_tmp = getLongArray(chain_param, "fwd_mem_tmp"); + if (!cbw) return NULL; + + long* bwd_tmp = getLongArray(chain_param, "bwd_mem_tmp"); + if (!cbw) return NULL; + + PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length"); + if (!chain_length_param) return NULL; + long chain_length = PyLong_AsLong(chain_length_param); + Py_DECREF(chain_length_param); + + // TODO: Can be optimized by only allocating memory for l >= i + // TODO: float / int instead of double / long ? +#define OPT(m, i, l) \ + opt[(m) * (chain_length + 1) * (chain_length + 1) + \ + (i) * (chain_length + 1) + (l)] + double* opt = (double*)calloc( + (mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double)); + +#define WHAT(m, i, l) \ + what[(m) * (chain_length + 1) * (chain_length + 1) + \ + (i) * (chain_length + 1) + (l)] + long* what = (long*)calloc( + (mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(long)); + + for (long m = 0; m <= mmax; ++m) + for (long i = 0; i <= chain_length; ++i) + // TODO: Can be optimized to remove the IF by reordering loops + if ((m >= cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) && + (m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i])) + OPT(m, i, i) = fw[i] + bw[i]; + else + OPT(m, i, i) = INFINITY; + + for (long m = 0; m <= mmax; ++m) + for (long d = 1; d <= chain_length; ++d) { + for (long i = 0; i <= chain_length - d; ++i) { + long idx = i + d; + long mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i]; + if (idx > i + 1) { + long maxCostFWD = 0; + for (long j = i + 1; j < idx; j++) { + maxCostFWD = fmaxl(maxCostFWD, cw[j] + cw[j + 1] + fwd_tmp[j]); + } + mmin = fmaxl(mmin, cw[idx + 1] + maxCostFWD); + } + if ((m >= mmin)) { + long bestLeaf = -1; + double sumFw = 0; + double bestLeafCost = INFINITY; + /// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j = + /// i+1 + for (long j = i + 1; j <= idx; ++j) { + sumFw += fw[j - 1]; + if (m >= cw[j]) { + double cost = sumFw + OPT(m - cw[j], j, idx) + OPT(m, i, j - 1); + if (cost < bestLeafCost) { + bestLeafCost = cost; + bestLeaf = j; + } + } + } + double chainCost = INFINITY; + if (m >= cbw[i + 1]) + chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, idx); + if (bestLeafCost <= chainCost) { + OPT(m, i, idx) = bestLeafCost; + WHAT(m, i, idx) = bestLeaf; + } else { + OPT(m, i, idx) = chainCost; + WHAT(m, i, idx) = -1; + } + } else + OPT(m, i, idx) = INFINITY; + } + } + + free(fw); + free(bw); + free(cw); + free(cbw); + free(fwd_tmp); + free(bwd_tmp); + + PyObject* res_opt = PyList_New(mmax + 1); + PyObject* res_what = PyList_New(mmax + 1); + + // Convert the result into Python world + for (long m = 0; m <= mmax; ++m) { + PyObject* res_opt_m = PyList_New(chain_length + 1); + PyList_SET_ITEM(res_opt, m, res_opt_m); + PyObject* res_what_m = PyList_New(chain_length + 1); + PyList_SET_ITEM(res_what, m, res_what_m); + for (long i = 0; i <= chain_length; ++i) { + PyObject* res_opt_m_i = PyDict_New(); + PyList_SET_ITEM(res_opt_m, i, res_opt_m_i); + PyObject* res_what_m_i = PyDict_New(); + PyList_SET_ITEM(res_what_m, i, res_what_m_i); + for (long l = i; l <= chain_length; ++l) { + PyObject* res_l = PyLong_FromLong(l); + PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l)); + PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l); + Py_DECREF(res_opt_m_i_l); + PyObject* res_what_m_i_l; + long what_m_i_l = WHAT(m, i, l); + if (what_m_i_l < 0) + res_what_m_i_l = Py_BuildValue("(O)", Py_True); + else + res_what_m_i_l = Py_BuildValue("(Ol)", Py_False, what_m_i_l); + PyDict_SetItem(res_what_m_i, res_l, res_what_m_i_l); + Py_DECREF(res_what_m_i_l); + Py_DECREF(res_l); + } + } + } + + free(opt); + free(what); + + PyObject* result = PyTuple_Pack(2, res_opt, res_what); + Py_DECREF(res_opt); + Py_DECREF(res_what); + return result; +} + +// long i = L - s, j = t - s, k = l - t +inline long floating_index_in_array(long m_factor, long m, long i, long j, + long k) { + return m * m_factor + (i * (i + 1) * (2 * i + 4)) / 12 + (i + 1) * j - + (j * (j - 1)) / 2 + k; +} + +typedef struct { + long sp; + long r; + long tp; +} index_t; + +static PyObject* floating_compute_table(PyObject* self, PyObject* args) { + PyObject* chain_param; + int mmax; + + if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL; + + double* fw = getDoubleArray(chain_param, "fweigth"); + if (!fw) return NULL; + + double* bw = getDoubleArray(chain_param, "bweigth"); + if (!bw) return NULL; + + long* cw = getLongArray(chain_param, "cweigth"); + if (!cw) return NULL; + + long* cbw = getLongArray(chain_param, "cbweigth"); + if (!cbw) return NULL; + + long* fwd_tmp = getLongArray(chain_param, "fwd_tmp"); + if (!fwd_tmp) return NULL; + + long* bwd_tmp = getLongArray(chain_param, "bwd_tmp"); + if (!bwd_tmp) return NULL; + + PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length"); + if (!chain_length_param) return NULL; + long chain_length = PyLong_AsLong(chain_length_param); + Py_DECREF(chain_length_param); + + const long m_factor = + (chain_length + 1) * (chain_length + 2) * (2 * chain_length + 6) / 12; + + // Defined for 0 <= s <= t <= l <= chain_length, for all m +#undef OPT +#define OPT(m, s, t, l) \ + opt[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \ + (l) - (t))] + double* opt = (double*)calloc((mmax + 1) * m_factor, sizeof(double)); + +#undef WHAT +#define WHAT(m, s, t, l) \ + what[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \ + (l) - (t))] + index_t* what = (index_t*)calloc((mmax + 1) * m_factor, sizeof(index_t)); + + double* partialSumsFW = (double*)calloc(chain_length + 1, sizeof(double)); + double total = 0; + for (long i = 0; i < chain_length; ++i) { + partialSumsFW[i] = total; + total += fw[i]; + } + partialSumsFW[chain_length] = total; + + for (long m = 0; m <= mmax; ++m) + for (long i = 0; i <= chain_length; ++i) { + // TODO: Can be optimized to remove the IF by reordering loops + if ((m >= cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) && + (m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i])) + OPT(m, i, i, i) = fw[i] + bw[i]; + else + OPT(m, i, i, i) = INFINITY; + } + + for (long m = 0; m <= mmax; ++m) + for (long d = 1; d <= chain_length; ++d) { // d = l - s + for (long s = 0; s <= chain_length - d; ++s) { + long l = s + d; + long memNullFirst = cw[l + 1] + cw[s + 1] + fwd_tmp[s]; + long memNullSecond = 0; + for (long j = s + 1; j < l; ++j) { + long val = cw[j] + cw[j + 1] + fwd_tmp[j]; + if (val > memNullSecond) memNullSecond = val; + } + for (long t = s; t <= l; ++t) { + double chainCost = INFINITY; + if ((s == t) && (m >= cw[l + 1] + cbw[s + 1] + fwd_tmp[s]) && + (m >= cw[s] + cw[s + 1] + cbw[s + 1] + bwd_tmp[s])) { + chainCost = OPT(m, s, s, s) + OPT(m - cbw[s + 1], s + 1, s + 1, l); + } + double bestLeafCost = INFINITY; + index_t bestLeaf = {.sp = -1, .r = -1, .tp = -1}; + if (m >= memNullFirst && m >= cw[l + 1] + memNullSecond) { + for (long r = s; r <= t; ++r) + if (cw[s] <= cw[r]) + for (long tp = t + 1; tp <= l; ++tp) + for (long sp = r + 1; sp <= tp; ++sp) { + long mp = m - cw[r] + cw[s]; + assert(mp >= 0); + if (mp >= cw[sp]) { + double value = partialSumsFW[sp] - partialSumsFW[s] + + OPT(mp - cw[sp], sp, tp, l) + + OPT(mp, r, t, tp - 1); + if (value < bestLeafCost) { + bestLeafCost = value; + bestLeaf.sp = sp; + bestLeaf.r = r; + bestLeaf.tp = tp; + } + } + } + } + if (bestLeaf.sp >= 0 && bestLeafCost <= chainCost) { + OPT(m, s, t, l) = bestLeafCost; + WHAT(m, s, t, l).sp = bestLeaf.sp; + WHAT(m, s, t, l).r = bestLeaf.r; + WHAT(m, s, t, l).tp = bestLeaf.tp; + } else { + OPT(m, s, t, l) = chainCost; + WHAT(m, s, t, l).sp = -1; + } + } + } + } + + free(fw); + free(bw); + free(cw); + free(cbw); + free(fwd_tmp); + free(bwd_tmp); + + PyObject* res_opt = PyList_New(mmax + 1); + PyObject* res_what = PyList_New(mmax + 1); + + // Convert the result into Python world + PyObject* true_tuple = Py_BuildValue("(O)", Py_True); + for (long m = 0; m <= mmax; ++m) { + PyObject* res_opt_m = PyDict_New(); + PyList_SET_ITEM(res_opt, m, res_opt_m); + PyObject* res_what_m = PyDict_New(); + PyList_SET_ITEM(res_what, m, res_what_m); + for (long s = 0; s <= chain_length; ++s) + for (long t = s; t <= chain_length; ++t) + for (long l = t; l <= chain_length; ++l) { + PyObject* key = Py_BuildValue("(lll)", s, t, l); + PyObject* value_opt = PyFloat_FromDouble(OPT(m, s, t, l)); + PyDict_SetItem(res_opt_m, key, value_opt); + PyObject* value_what = true_tuple; + index_t* idx_what = &WHAT(m, s, t, l); + if (idx_what->sp >= 0) + value_what = Py_BuildValue("(O(lll))", Py_False, idx_what->sp, + idx_what->r, idx_what->tp); + PyDict_SetItem(res_what_m, key, value_what); + if (value_what != true_tuple) Py_DECREF(value_what); + Py_DECREF(key); + Py_DECREF(value_opt); + } + } + + Py_DECREF(true_tuple); + + free(opt); + free(what); + + PyObject* result = PyTuple_Pack(2, res_opt, res_what); + Py_DECREF(res_opt); + Py_DECREF(res_what); + return result; +} + +static PyObject* griewank_heterogeneous_compute_table(PyObject* self, + PyObject* args) { + PyObject* chain_param; + int mmax; + + if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL; + + double* fw = getDoubleArray(chain_param, "fweigth"); + if (!fw) return NULL; + + double* bw = getDoubleArray(chain_param, "bweigth"); + if (!bw) return NULL; + + long* cw = getLongArray(chain_param, "cweigth"); + if (!cw) return NULL; + + long* cbw = getLongArray(chain_param, "cbweigth"); + if (!cbw) return NULL; + + PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length"); + if (!chain_length_param) return NULL; + long chain_length = PyLong_AsLong(chain_length_param); + Py_DECREF(chain_length_param); + + // TODO: Can be optimized by only allocating memory for l >= i + // TODO: float / int instead of double / long ? +#undef OPT +#define OPT(m, i, l) \ + opt[(m) * (chain_length + 1) * (chain_length + 1) + \ + (i) * (chain_length + 1) + (l)] + double* opt = (double*)calloc( + (mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double)); + + // Compute partial sums + double* sumfw = (double*)calloc(chain_length, sizeof(double)); + double* sumbw = (double*)calloc(chain_length + 1, sizeof(double)); + double* sumsumfw = (double*)calloc(chain_length, sizeof(double)); + + double total = 0; + for (long i = 0; i < chain_length; ++i) { + total += fw[i]; + sumfw[i] = total; + } + + total = 0; + for (long i = 0; i < chain_length + 1; ++i) { + total += bw[i]; + sumbw[i] = total; + } + + total = 0; + for (long i = 0; i < chain_length; ++i) { + total += sumfw[i]; + sumsumfw[i] = total; + } + + for (long m = 0; m <= mmax; ++m) + for (long i = 0; i <= chain_length; ++i) { + // TODO: Can be optimized to remove the IF by reordering loops + if ((m >= cbw[i]) && (m >= cw[i] + cbw[i + 1])) + OPT(m, i, i) = bw[i]; + else + OPT(m, i, i) = INFINITY; + + if (i < chain_length) { + long maxC = fmaxl(cw[i], cw[i + 1]); + long maxCB = fmaxl(cbw[i + 1], cbw[i + 2] + maxC); + if ((m >= cbw[i]) && (m >= cw[i] + maxCB)) + OPT(m, i, i + 1) = fw[i] + bw[i] + bw[i + 1]; + else + OPT(m, i, i + 1) = INFINITY; + } + } + + for (long m = 0; m <= mmax; ++m) + for (long i = 0; i + 2 <= chain_length; ++i) { + long mminCst = fmaxl(cbw[i], cbw[i + 1] + cw[i]); + long maxCW_il = fmax(fmax(cw[i], cw[i + 1]), cw[i + 2]); + long maxCostFWD = cw[i] + cbw[i + 2] + maxCW_il; + for (long l = i + 2; l <= chain_length; ++l) { + maxCW_il = fmax(maxCW_il, cw[l + 1]); + maxCostFWD = fmaxl(maxCostFWD, cw[i] + cw[l + 1] + maxCW_il); + long mmin = fmaxl(mminCst, maxCostFWD); + if ((m >= mmin)) { + double noCheckpointCost = sumbw[l] - (i > 0 ? sumbw[i - 1] : 0); + noCheckpointCost += + sumsumfw[l - 1] - + (i > 0 ? sumsumfw[i - 1] + (l - i) * sumfw[i - 1] : 0); + + double valueCost = INFINITY; + if (m >= cw[i]) { + double sumFwds = 0; + for (long j = i + 1; j < l; ++j) { + sumFwds += fw[j - 1]; + valueCost = fmin( + valueCost, sumFwds + OPT(m - cw[i], j, l) + OPT(m, i, j - 1)); + } + } + OPT(m, i, l) = fmin(noCheckpointCost, valueCost); + } else + OPT(m, i, l) = INFINITY; + } + } + + free(sumfw); + free(sumbw); + free(sumsumfw); + free(fw); + free(bw); + free(cw); + free(cbw); + + PyObject* res_opt = PyList_New(mmax + 1); + + // Convert the result into Python world + for (long m = 0; m <= mmax; ++m) { + PyObject* res_opt_m = PyList_New(chain_length + 1); + PyList_SET_ITEM(res_opt, m, res_opt_m); + for (long i = 0; i <= chain_length; ++i) { + PyObject* res_opt_m_i = PyDict_New(); + PyList_SET_ITEM(res_opt_m, i, res_opt_m_i); + for (long l = i; l <= chain_length; ++l) { + PyObject* res_l = PyLong_FromLong(l - i); + PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l)); + PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l); + Py_DECREF(res_opt_m_i_l); + Py_DECREF(res_l); + } + } + } + + free(opt); + + return res_opt; +} + +static PyMethodDef dynamic_programs_methods[] = { + {"persistent_compute_table", persistent_compute_table, METH_VARARGS, + "Compute the optimal table with the persistent algorithm."}, + {"floating_compute_table", floating_compute_table, METH_VARARGS, + "Compute the optimal table with the floating algorithm."}, + {"griewank_heterogeneous_compute_table", + griewank_heterogeneous_compute_table, METH_VARARGS, + "Compute the optimal table for the Griewank Heterogeneous Model."}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +static struct PyModuleDef dynamic_programs_module = { + PyModuleDef_HEAD_INIT, "dynamic_programs_C_version", /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + dynamic_programs_methods}; + +PyMODINIT_FUNC PyInit_dynamic_programs_C_version(void) { + return PyModule_Create(&dynamic_programs_module); +} diff --git a/colossalai/fx/passes/algorithms/linearize.py b/colossalai/fx/passes/algorithms/linearize.py new file mode 100644 index 0000000000000000000000000000000000000000..1a49364f5a7c7f5942036a85e6bfbef109a2b08e --- /dev/null +++ b/colossalai/fx/passes/algorithms/linearize.py @@ -0,0 +1,94 @@ +from typing import List, Any +from torch.fx import GraphModule, Node +from colossalai.fx.profiler import is_inplace + +# Common nodes are type of nodes that could be seen as attributes and remain +# unchanged throughout the whole model, it will be used several times by +# different blocks of model, so that it is hard for us to linearize the graph +# when we encounter those kinds of nodes. We let users to annotate some of the +# input as common node, such as attention mask, and the followings are some of +# the ops that could actually be seen as common nodes. With our common node prop, +# we could find some of the "real" common nodes (e.g. the real attention mask +# used in BERT and GPT), the rule is simple, for node who's parents are all common +# nodes or it's op belongs to the following operations, we view this node as a +# newly born common node. +# List of target name that could be seen as common node +COPS = ["getattr", "getitem", "size"] + + +def _is_cop(target: Any) -> bool: + """Check if an op could be seen as common node + + Args: + target (Any): node target + + Returns: + bool + """ + + if isinstance(target, str): + return target in COPS + else: + return target.__name__ in COPS + + +def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]: + """Linearizing the graph + + Args: + gm (GraphModule): GraphModule derived by tracing + cnode (List[str], optional): common node List, should be the subset of input. Default to None. + + Returns: + List[List[Node]]: List of list, each inside list of Node presents + the actual 'node' in linearized manner. + + Remarks: + We merge the inplace ops into the previous node. + """ + + def _is_sink() -> bool: + """Check if we can free all dependencies + + Returns: + bool + """ + + return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users)) + + # make sure that item in cnode is valid + if cnode: + for name in cnode: + try: + assert next(node for node in gm.graph.nodes if node.name == name).op == "placeholder", \ + f"common node {name} is not an input of the model" + except StopIteration: + raise ValueError(f"common node name {name} not in graph") + + else: + cnode = [] + + deps = {} + linearized_nodes = [] + region = [] + + for n in gm.graph.nodes: + if n.op != "placeholder" and n.op != "output": + for n_par in n._input_nodes: + if n_par.op != "placeholder" and n_par.name not in cnode: + deps[n_par] -= 1 + region.append(n) + + # if the node could free all dependencies in graph + # we could begin a new node + if _is_sink(): + linearized_nodes.append(region) + region = [] + + # propagate common node attr if possible + if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]) or _is_cop(n.target): + cnode.append(n.name) + else: + deps[n] = len([user for user in n.users if user.op != "output"]) + + return linearized_nodes diff --git a/colossalai/fx/passes/algorithms/operation.py b/colossalai/fx/passes/algorithms/operation.py new file mode 100644 index 0000000000000000000000000000000000000000..8bfa3452ba649b4e5c3664007d9829b66bbefba1 --- /dev/null +++ b/colossalai/fx/passes/algorithms/operation.py @@ -0,0 +1,270 @@ +import math + + +def _discretize(mem_unit, values): + return [math.ceil(value / mem_unit) for value in values] + + +class Chain: + + def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True): + self.fweight = fw + self.bweight = bw + self.cweight = cw + self.cbweight = cbw + self.fwd_mem_tmp = ftmp + self.bwd_mem_tmp = btmp + self.length = len(fw) + if check and not self.check_lengths(): + raise AttributeError("In Chain, input lists do not have consistent lengths") + + def check_lengths(self): + return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1) + and (len(self.cweight) == self.length + 1) and (len(self.fwd_mem_tmp) == self.length) + and (len(self.bwd_mem_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1)) + + def __repr__(self): + chain_list = [] + for i in range(self.length): + chain_list.append((self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_mem_tmp[i], + self.bwd_mem_tmp[i])) + i = self.length + chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i])) + return chain_list.__repr__() + + def _discretize(self, mem_unit): + self.cweight = _discretize(mem_unit, self.cweight) + self.cbweight = _discretize(mem_unit, self.cbweight) + self.fwd_mem_tmp = _discretize(mem_unit, self.fwd_mem_tmp) + self.bwd_mem_tmp = _discretize(mem_unit, self.bwd_mem_tmp) + + +class Operation: + + def shift(self, value): + if type(self.index) is tuple: + self.index = tuple(x + value for x in self.index) + else: + self.index += value + + +class Offload(Operation): + + def __init__(self, index, has_bar=False) -> None: + super().__init__() + self.index = index + self.name = "Off" + self.has_bar = has_bar + if self.has_bar: + self.name += "wBar" + + def __repr__(self): + return f"{self.name}_{self.index}" + + +class Prefetch(Operation): + + def __init__(self, index, has_bar=False) -> None: + super().__init__() + self.index = index + self.name = "Pre" + self.has_bar = has_bar + if self.has_bar: + self.name += "wBar" + + def __repr__(self): + return f"{self.name}_{self.index}" + + +class Forward(Operation): + + def __init__(self, index): + self.index = index + self.name = "F" + + def __repr__(self): + return "{n}_{i}".format(n=self.name, i=self.index) + + def cost(self, chain: Chain): + if chain is not None: + return chain.fweight[self.index] + else: + return 1 + + +class ForwardEnable(Forward): + + def __init__(self, index): + super().__init__(index) + self.name = "Fe" + + +class ForwardNograd(Forward): + + def __init__(self, index): + super().__init__(index) + self.name = "Fn" + + +class ForwardCheck(Forward): + + def __init__(self, index): + super().__init__(index) + self.name = "CF" + + +class Forwards(Operation): + + def __init__(self, start, end): + self.index = (start, end) + + def __repr__(self): + return "F_{i}->{j}".format(i=self.index[0], j=self.index[1]) + + def cost(self, chain: Chain): + if chain is not None: + return sum(chain.fweight[self.index[0]:self.index[1] + 1]) + else: + return (self.index[1] - self.index[0] + 1) + + +def isForward(op): + return type(op) is Forward or type(op) is Forwards + + +class Backward(Operation): + + def __init__(self, index): + self.index = index + + def __repr__(self): + return "B_{i}".format(i=self.index) + + def cost(self, chain: Chain): + if chain is not None: + return chain.bweight[self.index] + else: + return 1 + + +class Loss(Operation): + + def __init__(self): + pass + + def __repr__(self): + return "L" + + def cost(self, chain): + return 0 + + +class MemoryAccess(Operation): + + def __init__(self, index): + self.index = index + + def __repr__(self): + return "{n}_{i}".format(n=self.name, i=self.index) + + def cost(self, chain: Chain): + return 0 + + +class WriteMemory(MemoryAccess): + + def __init__(self, index): + super().__init__(index) + self.name = "WM" + + +class ReadMemory(MemoryAccess): + + def __init__(self, index): + super().__init__(index) + self.name = "RM" + + +class DiscardMemory(MemoryAccess): + + def __init__(self, index): + super().__init__(index) + self.name = "DM" + + +class Function: + + def __init__(self, name, *args): + self.name = name + self.args = args + self.str_args = ','.join(str(v) for v in self.args) + + def __repr__(self): + return "{n}({args})".format(n=self.name, args=self.str_args) + + +class Sequence: + + def __init__(self, function): + self.sequence = [] #List of Operation and Sequence + self.function = function #Description the function (name and parameters) + + def __repr__(self): + return repr(self.list_operations()) + + def list_operations(self): + op_list = [] + for x in self.sequence: + if isinstance(x, Operation): + op_list.append(x) + else: + assert isinstance(x, Sequence) + op_list += x.list_operations() + return op_list + + def insert(self, operation): + self.sequence.append(operation) + + def remove(self, operation_index): + del self.sequence[operation_index] + + def insert_sequence(self, sequence): + self.sequence.append(sequence) + + def shift(self, value): + for x in self.sequence: + x.shift(value) + return self + + def remove_useless_write(self): + if self.sequence: + if isinstance(self.sequence[0], WriteMemory): + self.remove(0) + return self + + def get_makespan(self, chain): + return sum(op.cost(chain) for op in self.list_operations()) + + def without_suffix(self): + ops = self.list_operations() + end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0] + try: + last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable) + except ValueError: + last_idx = -1 + if last_idx == end_of_first_phase - 1: + return (self, None) + chain_length = ops[end_of_first_phase - + 1].index ## Some assumption here about the sequence (finishes with Forward_L + start_of_fwd_enable_chain = ops[last_idx + 1].index ## And starts with B_L), but should be fine in practice + result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain)) + for i in range(last_idx + 1): + result.insert(ops[i]) + result.insert(Loss()) + for i in range(chain_length, start_of_fwd_enable_chain - 1, -1): + position = end_of_first_phase + 1 + (chain_length - i) + assert type(ops[position]) is Backward + assert ops[position].index == i + for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)): + result.insert(ops[i]) + return (result, start_of_fwd_enable_chain) diff --git a/colossalai/fx/passes/concrete_info_prop.py b/colossalai/fx/passes/concrete_info_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..ab38e8cb14e9436363a9de0491cb6cb5e55ba131 --- /dev/null +++ b/colossalai/fx/passes/concrete_info_prop.py @@ -0,0 +1,290 @@ +from dataclasses import asdict +from typing import Any, Dict, List, NamedTuple, Optional, Tuple + +import torch +import torch.fx +from torch.fx.node import Argument, Node, Target +from torch.utils._pytree import tree_flatten + +from colossalai.fx._compatibility import compatibility +from colossalai.fx.profiler import GraphInfo, profile_function, profile_method, profile_module + + +@compatibility(is_backward_compatible=True) +class ConcreteInfoProp(torch.fx.Interpreter): + """ + Execute an FX graph Node-by-Node with concrete tensor and record the memory + usage, execution time of forward and backward, and type of the result into + the corresponding node. + + Usage: + BATCH_SIZE = 2 + DIM_IN = 4 + DIM_HIDDEN = 16 + DIM_OUT = 16 + model = torch.nn.Sequential( + torch.nn.Linear(DIM_IN, DIM_HIDDEN), + torch.nn.Linear(DIM_HIDDEN, DIM_OUT), + ).cuda() + input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda") + gm = symbolic_trace(model) + interp = ConcreteInfoProp(gm) + interp.run(input_sample) + print(interp.summary(unit='kb')) + + + output of above code is + Op type Op Forward time Backward time SAVE_FWD_IN FWD_OUT FWD_TMP BWD_OUT BWD_TMP + ----------- ------- ----------------------- ------------------------ ------------- --------- --------- --------- --------- + placeholder input_1 0.0 s 0.0 s False 0.00 KB 0.00 KB 0.00 KB 0.00 KB + call_module _0 0.0003993511199951172 s 0.00706791877746582 s False 0.50 KB 0.00 KB 0.03 KB 0.66 KB + call_module _1 6.29425048828125e-05 s 0.00018286705017089844 s False 0.50 KB 0.00 KB 0.12 KB 0.81 KB + output output 0.0 s 0.0 s True 0.00 KB 0.00 KB 0.00 KB 0.00 KB + Args: + module (GraphModule): The module to be executed + + """ + + _is_proped: bool = False + + def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any: + """Customized run for ConcreteInfoProp + We need to store the device in self.device + + Args: + *args: The arguments to the Module to run, in positional order + initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution. + This is a dict mapping `Node` to any value. This can be used, for example, to + pre-populate results for certain `Nodes` so as to do only partial evaluation within + the interpreter. + enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and + process_outputs function first before using them. + + Returns: + Any: The value returned from executing the Module + """ + + flatten_args, _ = tree_flatten(args) + self.device = next(item for item in flatten_args if hasattr(item, "device")).device + return super().run(*args, initial_env, enable_io_processing) + + @compatibility(is_backward_compatible=True) + def run_node(self, n: Node) -> Any: + """ + Run a specific node ``n`` and return the result. + Calls into placeholder, get_attr, call_function, + call_method, call_module, or output depending + on ``node.op`` + + Args: + n (Node): The Node to execute + + Returns: + Any: The result of executing ``n`` + """ + self._is_proped = True + result, meta_info = super().run_node(n) + + n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` + # TODO: the attribute node_size should be removed in the future + setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0)) + n.meta['type'] = type(result) + + # retain the autograd graph + for param in self.module.parameters(): + param.grad = None + + return result + + # Main Node running APIs + @compatibility(is_backward_compatible=True) + def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``placeholder`` node. Note that this is stateful: + ``Interpreter`` maintains an internal iterator over + arguments passed to ``run`` and this method returns + next() on that iterator. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Returns: + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and forward & backward time. + """ + return super().placeholder(target, args, kwargs), GraphInfo() + + @compatibility(is_backward_compatible=True) + def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``get_attr`` node. Will retrieve an attribute + value from the ``Module`` hierarchy of ``self.module``. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + return super().get_attr(target, args, kwargs), GraphInfo() + + @compatibility(is_backward_compatible=True) + def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_function`` node with meta tensor and return the result and its meta profile. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and forward & backward time. + """ + assert not isinstance(target, str) + return profile_function(target, self.device)(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_method`` node with meta tensor and return the result and its meta profile. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and forward & backward time. + """ + return profile_method(target, self.device)(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_module`` node with meta tensor and return the result and its meta profile. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and forward & backward time. + """ + # Retrieve executed args and kwargs values from the environment + # Execute the method and return the result + assert isinstance(target, str) + submod = self.fetch_attr(target) + return profile_module(submod, self.device)(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute an ``output`` node. This really just retrieves + the value referenced by the ``output`` node and returns it. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and forward & backward time. + """ + return args[0], GraphInfo(save_fwd_in=True) + + def propagate(self, *args): + """ + Run `module` via interpretation and return the result and + record the shape and type of each node. + + Args: + *args (Tensor): the sample input. + + Returns: + Any: The value returned from executing the Module + """ + return super().run(*args) + + def summary(self, unit: str = 'MB') -> str: + """ + Summarizes the memory and FLOPs statistics of the `GraphModule` in + tabular format. Note that this API requires the ``tabulate`` module + to be installed. + """ + # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py + try: + from tabulate import tabulate + except ImportError: + print("`summary` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library.") + + assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`." + + # Build up a list of summary information for each node + node_summaries: List[List[Any]] = [] + + def mem_repr(mem: int) -> str: + unit_divisor_map = { + 'kb': 1024, + 'mb': 1024**2, + 'gb': 1024**3, + 'tb': 1024**4, + } + return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}" + + def time_repr(time: float): + return f"{time:,} s" + + for node in self.module.graph.nodes: + node: Node + node_summaries.append([ + node.op, + str(node), + time_repr(node.meta['fwd_time']), + time_repr(node.meta['bwd_time']), + node.meta['save_fwd_in'], + mem_repr(node.meta['fwd_mem_out']), + mem_repr(node.meta['fwd_mem_tmp']), + mem_repr(node.meta['bwd_mem_out']), + mem_repr(node.meta['bwd_mem_tmp']), + ]) + + # Use the ``tabulate`` library to create a well-formatted table + # presenting our summary information + headers: List[str] = [ + 'Op type', + 'Op', + 'Forward time', + 'Backward time', + 'SAVE_FWD_IN', + 'FWD_OUT', + 'FWD_TMP', + 'BWD_OUT', + 'BWD_TMP', + ] + + return tabulate(node_summaries, headers=headers, stralign='right') diff --git a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..f28d65e2668ac39e7b189c7d181d018468648614 --- /dev/null +++ b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py @@ -0,0 +1,111 @@ +import torch +from typing import List +from torch.fx import symbolic_trace +from torch.fx.node import Node +from colossalai.fx.passes.split_module import split_module +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec +import builtins +import operator +from copy import deepcopy + + +def apply(*args, **kwargs): + shape_consistency_manager = ShapeConsistencyManager() + return shape_consistency_manager.apply(*args, **kwargs) + + +def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh): + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + + # the dict to get origin sharding spec of node + origin_node_sharding_spec_dict = {} + for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)): + strategies_vector = node.strategies_vector + setattr(node, 'best_strategy', strategies_vector[strategy_index]) + setattr(node, 'sharding_spec', strategies_vector[strategy_index].output_sharding_spec) + origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec + + # apply the sharding spec of parameters + for node in nodes: + if node.op == 'call_module': + target_module = node.graph.owning_module.get_submodule(node.target) + origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {}) + setattr(target_module.weight, 'sharding_spec', origin_sharding_spec) + target_weight_sharding_spec = node.best_strategy.input_shardings[1] + target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3)) + apply(target_module.weight, target_weight_sharding_spec) + target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3)) + + # the dict to get input sharding specs of user node + sharding_spec_convert_dict = {} + for index, node in enumerate(nodes): + target_sharding_specs = [] + for user_node in node.strategies_vector.successor_nodes: + node_index = user_node.strategies_vector.predecessor_nodes.index(node) + target_sharding_spec = user_node.best_strategy.input_shardings[node_index] + target_sharding_specs.append(target_sharding_spec) + sharding_spec_convert_dict[index] = target_sharding_specs + + # add above dicts into graph + for node in nodes: + if node.op != 'placeholder': + with mod_graph.inserting_before(node): + input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict') + origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict') + break + + return sharding_spec_convert_dict, origin_node_sharding_spec_dict + + +def shape_consistency_pass(gm: torch.fx.GraphModule): + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + input_dict_node = None + origin_dict_node = None + + # mapping the node into the origin graph index + node_to_index_dict = {} + index = 0 + for node in nodes: + if node.target == 'sharding_spec_convert_dict': + input_dict_node = node + continue + if node.target == 'origin_node_sharding_spec_dict': + origin_dict_node = node + continue + if not hasattr(node, 'best_strategy'): + continue + node_to_index_dict[node] = index + index += 1 + assert input_dict_node is not None + + # add shape consistency apply function into graph + for node in nodes: + if not hasattr(node, 'best_strategy'): + continue + with mod_graph.inserting_after(node): + origin_spec_node = mod_graph.create_node('call_function', + operator.getitem, + args=(origin_dict_node, node_to_index_dict[node])) + with mod_graph.inserting_after(origin_spec_node): + set_sharding_spec_node = mod_graph.create_node('call_function', + builtins.setattr, + args=(node, 'sharding_spec', origin_spec_node)) + + for user_node in node.strategies_vector.successor_nodes: + node_index = user_node.strategies_vector.predecessor_nodes.index(node) + with mod_graph.inserting_before(user_node): + input_specs_node = mod_graph.create_node('call_function', + operator.getitem, + args=(input_dict_node, node_to_index_dict[node])) + with mod_graph.inserting_before(user_node): + sharding_spec_node = mod_graph.create_node('call_function', + operator.getitem, + args=(input_specs_node, node_index)) + with mod_graph.inserting_before(user_node): + shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node)) + + return gm diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..5137494ada6fb47207d7f96e8e93c1e96cc91123 --- /dev/null +++ b/colossalai/fx/passes/meta_info_prop.py @@ -0,0 +1,352 @@ +from dataclasses import asdict +from typing import Any, Dict, List, NamedTuple, Tuple + +import torch +import torch.fx +from torch.fx.node import Argument, Node, Target +from torch.utils._pytree import tree_map + +from colossalai.fx._compatibility import compatibility, is_compatible_with_meta +from colossalai.fx.profiler import ( + GraphInfo, + activation_size, + calculate_fwd_in, + calculate_fwd_out, + calculate_fwd_tmp, + profile_function, + profile_method, + profile_module, +) + + +@compatibility(is_backward_compatible=True) +class TensorMetadata(NamedTuple): + # TensorMetadata is a structure containing pertinent information + # about a tensor within a PyTorch program. + + shape: torch.Size + dtype: torch.dtype + requires_grad: bool + stride: Tuple[int] + numel: int + is_tensor: bool + # TODO: we can add a list of sharding spec here, and record the sharding + # behaviour by appending sharding spec into list. + + +def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: + """ + Extract a TensorMetadata NamedTuple describing `result`. + """ + shape = result.shape + dtype = result.dtype + requires_grad = result.requires_grad + stride = result.stride() + numel = result.numel() + is_tensor = True + + return TensorMetadata(shape, dtype, requires_grad, stride, numel, is_tensor) + + +@compatibility(is_backward_compatible=True) +class MetaInfoProp(torch.fx.Interpreter): + """ + Execute an FX graph Node-by-Node with meta tensor and + record the memory usage, FLOPs, and type of the result + into the corresponding node. + + Usage: + BATCH_SIZE = 2 + DIM_IN = 4 + DIM_HIDDEN = 16 + DIM_OUT = 16 + model = torch.nn.Sequential( + torch.nn.Linear(DIM_IN, DIM_HIDDEN), + torch.nn.Linear(DIM_HIDDEN, DIM_OUT), + ) + input_sample = torch.rand(BATCH_SIZE, DIM_IN) + gm = symbolic_trace(model) + interp = MetaInfoProp(gm) + interp.run(input_sample) + print(interp.summary(format='kb')) # don't panic if some statistics are 0.00 MB + + + # output of above code is + Op type Op Forward FLOPs Backward FLOPs FWD_OUT FWD_TMP BWD_OUT BWD_TMP + ----------- ------- --------------- ---------------- --------- --------- --------- --------- + placeholder input_1 0 FLOPs 0 FLOPs 0.00 KB 0.00 KB 0.00 KB 0.00 KB + call_module _0 128 FLOPs 288 FLOPs 0.12 KB 0.00 KB 0.34 KB 0.00 KB + call_module _1 512 FLOPs 1,056 FLOPs 0.12 KB 0.00 KB 1.19 KB 0.00 KB + output output 0 FLOPs 0 FLOPs 0.00 KB 0.00 KB 0.00 KB 0.00 KB + Args: + module (GraphModule): The module to be executed + + """ + + _is_proped: bool = False + + @compatibility(is_backward_compatible=True) + def run_node(self, n: Node) -> Any: + """ + Run a specific node ``n`` and return the result. + Calls into placeholder, get_attr, call_function, + call_method, call_module, or output depending + on ``node.op`` + + Args: + n (Node): The Node to execute + + Returns: + Any: The result of executing ``n`` + """ + self._is_proped = True + result, meta_info = super().run_node(n) + + def extract_tensor_meta(obj): + if isinstance(obj, torch.Tensor): + return _extract_tensor_metadata(obj) + else: + return TensorMetadata(None, None, False, None, 0, False) + + tensor_meta = tree_map(extract_tensor_meta, result) + n.meta['tensor_meta'] = tensor_meta + n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` + # TODO: the attribute node_size should be removed in the future + setattr(n, 'node_size', activation_size(n.meta.get('fwd_in', 0)) + activation_size(n.meta.get('fwd_tmp', 0))) + n.meta['type'] = type(result) + + # retain the autograd graph + for param in self.module.parameters(): + param.grad = None + + return result + + # Main Node running APIs + @compatibility(is_backward_compatible=True) + def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``placeholder`` node. Note that this is stateful: + ``Interpreter`` maintains an internal iterator over + arguments passed to ``run`` and this method returns + next() on that iterator. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Returns: + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + return super().placeholder(target, args, kwargs), GraphInfo() + + @compatibility(is_backward_compatible=True) + def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``get_attr`` node. Will retrieve an attribute + value from the ``Module`` hierarchy of ``self.module``. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + return super().get_attr(target, args, kwargs), GraphInfo() + + @compatibility(is_backward_compatible=True) + def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_function`` node with meta tensor and return the result and its meta profile. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + assert not isinstance(target, str) + return profile_function(target)(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_method`` node with meta tensor and return the result and its meta profile. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + return profile_method(target)(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_module`` node with meta tensor and return the result and its meta profile. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + # Retrieve executed args and kwargs values from the environment + # Execute the method and return the result + assert isinstance(target, str) + submod = self.fetch_attr(target) + return profile_module(submod)(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute an ``output`` node. This really just retrieves + the value referenced by the ``output`` node and returns it. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + if hasattr(args[0], '_tensor'): + return args[0], GraphInfo(fwd_in=[args[0]._tensor]) + return args[0], GraphInfo(save_fwd_in=True) + + def propagate(self, *args): + """ + Run `module` via interpretation and return the result and + record the shape and type of each node. + + Args: + *args (Tensor): the sample input. + + Returns: + Any: The value returned from executing the Module + """ + return super().run(*args) + + def summary(self, unit: str = 'MB') -> str: + """ + Summarizes the memory and FLOPs statistics of the `GraphModule` in + tabular format. Note that this API requires the ``tabulate`` module + to be installed. + """ + # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py + try: + from tabulate import tabulate + except ImportError: + print("`summary` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library.") + + assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`." + + # Build up a list of summary information for each node + node_summaries: List[List[Any]] = [] + + def mem_repr(mem: int) -> str: + unit_divisor_map = { + 'kb': 1024, + 'mb': 1024**2, + 'gb': 1024**3, + 'tb': 1024**4, + } + return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}" + + def flops_repr(flop: int) -> str: + return f"{flop:,} FLOPs" + + for node in self.module.graph.nodes: + node: Node + node_summaries.append([ + node.op, + str(node), + flops_repr(node.meta['fwd_flop']), + flops_repr(node.meta['bwd_flop']), + mem_repr(calculate_fwd_in(node)), + mem_repr(calculate_fwd_out(node)), + mem_repr(calculate_fwd_tmp(node)), + mem_repr(node.meta['bwd_mem_out']), + mem_repr(node.meta['bwd_mem_tmp']), + ]) + + # Use the ``tabulate`` library to create a well-formatted table + # presenting our summary information + headers: List[str] = [ + 'Op type', + 'Op', + 'Forward FLOPs', + 'Backward FLOPs', + 'FWD_IN', + 'FWD_OUT', + 'FWD_TMP', + 'BWD_OUT', + 'BWD_TMP', + ] + + return tabulate(node_summaries, headers=headers, stralign='right') + + +def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None: + """ + MetaInfo tracing API + + Given a ``GraphModule`` and a sample input, this API will trace the MetaInfo of a single training cycle, + and annotate them on ``gm.graph``. + + Uses: + >>> model = ... + >>> gm = symbolic_trace(model) + >>> args = ... # sample input to the ``GraphModule`` + >>> metainfo_trace(gm, *args) + + Args: + gm (torch.fx.GraphModule): The ``GraphModule`` to be annotated with MetaInfo. + verbose (bool, optional): Whether to show ``MetaInfoProp.summary()`. Defaults to False. + unit (str, optional): The unit of memory. Defaults to "MB". + + Returns: + torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo. + """ + device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + interp = MetaInfoProp(gm.to(device)) + if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + args = tree_map(lambda x: MetaTensor(x, fake_device=device), args) + kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs) + interp.propagate(*args, **kwargs) + if verbose: + interp.summary(unit) + gm.to('cpu') + del interp + return gm diff --git a/colossalai/fx/passes/passes_for_gpt2_test.py b/colossalai/fx/passes/passes_for_gpt2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f98fcd686ea44f42eeb91de7b3963748bdd65d70 --- /dev/null +++ b/colossalai/fx/passes/passes_for_gpt2_test.py @@ -0,0 +1,370 @@ +import torch +from torch.fx.graph_module import GraphModule +from typing import Callable, List, Dict, Any, Optional +from torch.fx._compatibility import compatibility +from packaging import version +from colossalai.fx.passes.meta_info_prop import TensorMetadata +import inspect +from typing import List +from colossalai.fx.passes.split_module import Partition +from colossalai.fx.passes.adding_split_node_pass import pipe_split, balanced_split_pass +from torch.fx.node import Node + + +def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]): + ''' + This pass is only used to do the gpt2 performance test, it may move into adding_split_node_pass.py, and will be deprecated in future. + ''' + mod_graph = gm.graph + valid_children_size = 0 + valid_children = [] + for node in mod_graph.nodes: + if node.op == "call_module": + valid_children_size += 1 + valid_children.append(node.target) + if valid_children_size < pp_size: + # If valid children is not enough to shard, we will use balanced policy instead of uniform policy. + return balanced_split_pass(gm, pp_size) + accumulate_layer_amount = 0 + list_of_part = partition_list + part_index = 0 + for node in mod_graph.nodes: + if pp_size <= 1: + break + if node.op == "call_module": + if node.target in valid_children: + accumulate_layer_amount += 1 + if accumulate_layer_amount == list_of_part[part_index]: + part_index += 1 + pp_size -= 1 + with mod_graph.inserting_after(node): + split_node = mod_graph.create_node('call_function', pipe_split) + + gm.recompile() + return gm + + +def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule): + ''' + This pass will be used in gpt2 test, only a part of changes may be added into + split_with_split_nodes_pass, and it will be deprecated in future. + ''' + part_idx = 0 + + def eliminate_unused_placeholders(gm): + for node in gm.graph.nodes: + if node.op == 'placeholder': + if not len(node.users): + gm.graph.erase_node(node) + gm.recompile() + return gm + + def refill_outputs_and_placeholders(gm, next_partition_placeholders): + ''' + This method is used to eliminate the outputs in previous partition which is unused in next partition. + In split module pass, it treats partitions as a DAG, but we need treat them as a single direction linked list in pipeline parallel. + The difference is if a output from partition 0 is an input argument of partition 3, the DAG will not transfer it + to partition 1 and partition 2. However, in single direction linked list, we need to do so. + ''' + output_type = None + output_args = [] + non_output_list = [] + new_placeholder_list = [] + for node in gm.graph.nodes: + if node.op == 'output': + if isinstance(node.args[0], (tuple, list)): + output_type = node.args[0].__class__ + output_args.extend([n.name for n in node.args[0]]) + else: + output_args.append(node.args[0].name) + rm_list = [] + for name in output_args: + if next_partition_placeholders and name not in next_partition_placeholders: + rm_list.append(name) + for name in rm_list: + output_args.remove(name) + gm.graph.erase_node(node) + else: + non_output_list.append(node.name) + + for name in next_partition_placeholders: + if name not in output_args: + output_args.append(name) + + for name in output_args: + if name not in non_output_list: + gm.graph.placeholder(name) + + # convert name to node for output_args + for index, name in enumerate(output_args): + for n in gm.graph.nodes: + if n.name == name: + output_args[index] = n + continue + + # reorder the output args to make sure + # output args has same order as next partition placeholder + reorder_output_args = [] + if next_partition_placeholders: + for name in next_partition_placeholders: + for node in output_args: + if node.name == name: + reorder_output_args.append(node) + continue + + for node in gm.graph.nodes: + if node.op == 'placeholder': + new_placeholder_list.append(node.name) + if output_type is not None: + gm.graph.output(output_type(output_args)) + else: + gm.graph.output(output_args) + gm.recompile() + return gm, new_placeholder_list + + def split_callback(n: torch.fx.Node): + nonlocal part_idx + if (n.op, n.target) == ('call_function', pipe_split): + part_idx += 1 + return part_idx + + split_mod = split_module_for_gpt2_test(annotated_gm, None, split_callback) + split_submodules = [] + for name, submodule in split_mod.named_modules(): + if isinstance(submodule, torch.fx.GraphModule): + for node in submodule.graph.nodes: + if (node.op, node.target) == ('call_function', pipe_split): + submodule.graph.erase_node(node) + submodule.recompile() + split_submodules.append(submodule) + + submodules = list(split_mod.children()) + placeholder_dict = {} + for submodule in submodules: + submodule = eliminate_unused_placeholders(submodule) + placeholder_dict[submodule] = [] + submodules.reverse() + for index, submodule in enumerate(submodules): + if index == 0: + placeholder_list = [] + else: + placeholder_list = placeholder_dict[submodules[index - 1]] + submodule, placeholder_dict[submodule] = refill_outputs_and_placeholders(submodule, placeholder_list) + submodule.recompile() + + split_mod.recompile() + + return split_mod, split_submodules + + +@compatibility(is_backward_compatible=True) +def split_module_for_gpt2_test( + m: GraphModule, + root_m: torch.nn.Module, + split_callback: Callable[[torch.fx.node.Node], int], +): + """ + This pass will be used in gpt2 pp performance test, only a part of changes may be added into + split_module, and it will be deprecated in future. + """ + partitions: Dict[str, Partition] = {} + orig_nodes: Dict[str, torch.fx.node.Node] = {} + + def _node_with_all_tensor_element(node_metadata: Any) -> int: + """ + return whether node contains non-tensor element. + """ + all_tensor_node = True + + if isinstance(node_metadata, TensorMetadata): + all_tensor_node = node_metadata.is_tensor and all_tensor_node + elif isinstance(node_metadata, dict): + value_list = [v for _, v in node_metadata.items()] + all_tensor_node += _node_with_all_tensor_element(value_list) + else: + for element in node_metadata: + all_tensor_node += _node_with_all_tensor_element(element) + + return all_tensor_node + + def _move_all_ancestors_into_partition(node, partition_name): + all_ancestors = set() + + def _gen_all_ancestors_set(node): + all_ancestors.add(node) + for n in node.all_input_nodes: + if n in all_ancestors: + continue + _gen_all_ancestors_set(n) + + _gen_all_ancestors_set(node) + for n in list(all_ancestors): + if n.op != 'placeholder' and n._fx_partition > partition_name: + n._fx_partition = partition_name + + def record_cross_partition_use(def_node: torch.fx.node.Node, + use_node: Optional[torch.fx.node.Node]): # noqa: B950 + def_partition_name = getattr(def_node, '_fx_partition', None) + use_partition_name = getattr(use_node, '_fx_partition', None) + if def_partition_name != use_partition_name: + # if 'tensor_meta' in def_node.meta: + # if not _node_with_all_tensor_element(def_node.meta['tensor_meta']): + # _move_all_ancestors_into_partition(use_node, def_partition_name) + # node_process_list.extend(use_node.all_input_nodes) + # node_process_list.extend(list(use_node.users)) + # node_process_list.append(use_node) + + # return + + if def_partition_name is not None: + def_partition = partitions[def_partition_name] + def_partition.outputs.setdefault(def_node.name) + if use_partition_name is not None: + def_partition.partition_dependents.setdefault(use_partition_name) + + if use_partition_name is not None: + use_partition = partitions[use_partition_name] + use_partition.inputs.setdefault(def_node.name) + if def_partition_name is not None: + use_partition.partitions_dependent_on.setdefault(def_partition_name) + + node_process_list = list(m.graph.nodes) + # split nodes into parititons + while node_process_list: + node = node_process_list.pop(0) + orig_nodes[node.name] = node + + if node.op in ["placeholder"]: + continue + if node.op == 'output': + # partition_name = str(split_callback(node)) + # def _set_output_args_partition(n, partition_name): + # n._fx_partition = partition_name + # torch.fx.graph.map_arg(node.args[0], lambda n: _set_output_args_partition(n, partition_name)) + torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None)) + continue + partition_name = str(split_callback(node)) + + # add node to partitions + partition = partitions.get(partition_name) + if partition is None: + partitions[partition_name] = partition = Partition(partition_name) + + partition.node_names.append(node.name) + origin_partition_name = getattr(node, '_fx_partition', None) + if origin_partition_name is None: + node._fx_partition = partition_name + + torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node)) + torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 + + # find partitions with no dependencies + root_partitions: List[str] = [] + for partition_name, partition in partitions.items(): + if not len(partition.partitions_dependent_on): + root_partitions.append(partition_name) + + # check partitions for circular dependencies and create topological partition ordering + sorted_partitions: List[str] = [] + while root_partitions: + root_partition = root_partitions.pop() + sorted_partitions.append(root_partition) + for dependent in partitions[root_partition].partition_dependents: + partitions[dependent].partitions_dependent_on.pop(root_partition) + if not partitions[dependent].partitions_dependent_on: + root_partitions.append(dependent) + if len(sorted_partitions) != len(partitions): + raise RuntimeError("cycle exists between partitions!") + + # add placeholders to parititons + for partition_name in sorted_partitions: + partition = partitions[partition_name] + for input in partition.inputs: + placeholder = partition.graph.placeholder(input) + placeholder.meta = orig_nodes[input].meta.copy() + partition.environment[orig_nodes[input]] = placeholder + + # Transform nodes and collect targets for partition's submodule + for node in m.graph.nodes: + if hasattr(node, '_fx_partition'): + partition = partitions[node._fx_partition] + + # swap out old graph nodes in kw/args with references to new nodes in this submodule + environment = partition.environment + gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) + gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n]) + + if node.op not in ['call_module', 'get_attr']: + target = node.target + else: + target_atoms = node.target.split('.') + target_attr = m + for atom in target_atoms: + if not hasattr(target_attr, atom): + raise RuntimeError(f'Operator target {node.target} not found!') + target_attr = getattr(target_attr, atom) + # target = target_atoms[-1] + target = '_'.join(target_atoms) + partition.targets[target] = target_attr + + assert isinstance(gathered_args, tuple) + assert isinstance(gathered_kwargs, dict) + new_node = partition.graph.create_node(op=node.op, + target=target, + args=gathered_args, + kwargs=gathered_kwargs, + name=node.name) + new_node.meta = node.meta.copy() + partition.environment[node] = new_node + + # Set up values to construct base module + base_mod_env: Dict[str, torch.fx.node.Node] = {} + base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() + base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} + for node in m.graph.nodes: + if node.op == 'placeholder': + if version.parse(torch.__version__) < version.parse('1.11.0'): + base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type) + else: + default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty + base_mod_env[node.name] = base_mod_graph.placeholder(node.name, + type_expr=node.type, + default_value=default_value) + base_mod_env[node.name].meta = node.meta.copy() + + # Do some things iterating over the partitions in topological order again: + # 1) Finish off submodule Graphs by setting corresponding outputs + # 2) Construct GraphModules for each submodule + # 3) Construct the base graph by emitting calls to those submodules in + # topological order + + for partition_name in sorted_partitions: + partition = partitions[partition_name] + + # Set correct output values + output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs) + output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] + partition.graph.output(output_vals) + + # Construct GraphModule for this partition + submod_name = f'submod_{partition_name}' + base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, + partition.graph) # noqa: B950 + + # Emit call in base graph to this submodule + output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs)) + if len(partition.outputs) > 1: + # Unpack multiple return values from submodule + output_val_proxy = torch.fx.proxy.Proxy(output_val) + for i, output_name in enumerate(partition.outputs): + base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] + else: + if not partition.outputs: + continue + base_mod_env[list(partition.outputs)[0]] = output_val + + for node in m.graph.nodes: + if node.op == 'output': + base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 + + return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..d2bad06bb45a1393543039ef55dea6c8d1e9f50b --- /dev/null +++ b/colossalai/fx/passes/shard_1d_pass.py @@ -0,0 +1,151 @@ +import torch +import torch.nn as nn +import operator +from colossalai.tensor import ProcessGroup +from colossalai.tensor.distspec import ShardSpec +from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec + +ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] +ELEMENTWISE_FUNC_OP = [ + torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, + operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout +] + + +def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter: + """weight_split + split a nn.Parameter + + Args: + weight (torch.nn.parameter.Parameter): a torch Parameter instance + dim (int): the dimension to be sharded along with + col_normal(bool): col shard with gather or not + Returns: + _type_: _description_ + """ + if col_normal: + setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_normal")) + else: + setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_needs_many_outputs")) + return weight + + +def column_shard_linear_pass(gm: torch.fx.GraphModule): + # Split all the linear module with column shard. Currently for testing only. + mod_graph = gm.graph + for node in mod_graph.nodes: + if node.op == "call_module": + target_module = node.graph.owning_module.get_submodule(node.target) + if isinstance(target_module, torch.nn.Linear): + target_module.weight = weight_split(target_module.weight, dim=0, col_normal=False) + if target_module.bias is not None: + target_module.bias.data = weight_split(target_module.bias.data, dim=0, col_normal=False) + + gm.recompile() + return gm + + +def row_shard_linear_pass(gm: torch.fx.GraphModule): + # Split all the linear module with row shard. Currently for testing only. + mod_graph = gm.graph + for node in mod_graph.nodes: + if node.op == "call_module": + target_module = node.graph.owning_module.get_submodule(node.target) + if isinstance(target_module, torch.nn.Linear): + target_module.weight = weight_split(target_module.weight, dim=-1, col_normal=False) + + gm.recompile() + return gm + + +def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup): + """ + This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers. + """ + #TODO: Needs to handle special cases, like x = linear(x) + linear(x) + graph = graph_module.graph + world_size = process_group.world_size() + + def _traverse_and_annotate(node, start_tracking, annotation_record, world_size): + # traverse the graph to look for consecutive linear layers + is_linear_module = False + + if node.op == 'call_module': + # look for the linear layer + module = node.graph.owning_module.get_submodule(node.target) + if isinstance(module, nn.Linear): + is_linear_module = True + if start_tracking: + # when start_tracking = True + # it means the first linear has been found and the current module + # is the second linear + # set the current linear module to be row-sharded + annotation_record['row'] = module + + for shard_type, module in annotation_record.items(): + # add row sharding spec + if shard_type == 'row': + dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size]) + comp_spec = ComputeSpec(ComputePattern.TP1D) + setattr(module.weight, 'pg', process_group) + setattr(module.weight, 'dist_spec', dist_spec) + setattr(module.weight, 'comp_spec', comp_spec) + elif shard_type == 'col': + weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size]) + weight_comp_spec = ComputeSpec(ComputePattern.TP1D) + weight_comp_spec.output_replicate = False + setattr(module.weight, 'pg', process_group) + setattr(module.weight, 'dist_spec', weight_dist_spec) + setattr(module.weight, 'comp_spec', weight_comp_spec) + + if module.bias is not None: + bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size]) + bias_comp_spec = ComputeSpec(ComputePattern.TP1D) + bias_comp_spec.output_replicate = False + setattr(module.bias, 'pg', process_group) + setattr(module.bias, 'dist_spec', bias_dist_spec) + setattr(module.bias, 'comp_spec', bias_comp_spec) + start_tracking = False + annotation_record.clear() + else: + # when start tracking = False + # it means the current layer is the first linear + # set the linear layer to be col-sharded + start_tracking = True + annotation_record['col'] = module + + if start_tracking and not is_linear_module: + # check against the white list + # if non-element wise op is found, we reset the tracking + if node.op == 'call_module': + module = node.graph.owning_module.get_submodule(node.target) + if module.__class__ not in ELEMENTWISE_MODULE_OP: + start_tracking = False + elif node.op == 'call_function' or node.op == 'call_method': + if node.target not in ELEMENTWISE_FUNC_OP: + start_tracking = False + elif len(node.users.keys()) > 1: + start_tracking = False + + if not start_tracking: + annotation_record.clear() + + # stop tracking for consecutive linear when branch is found + # e.g. + # out1 = self.linear1(x) + # out2 = self.linear2(x) + # return out1+out2 + next_nodes = list(node.users.keys()) + if len(next_nodes) > 1: + start_tracking = False + annotation_record.clear() + + # traverse + for node in next_nodes: + _traverse_and_annotate(node, start_tracking, annotation_record, world_size) + + placeholder_node = list(graph.nodes)[0] + annotate_record = {} + _traverse_and_annotate(placeholder_node, False, annotate_record, world_size) + + return graph_module diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py new file mode 100644 index 0000000000000000000000000000000000000000..bc257edc8c890f67e0bd66b45d6987bf9be71417 --- /dev/null +++ b/colossalai/fx/passes/split_module.py @@ -0,0 +1,297 @@ +import torch +from torch.fx.graph_module import GraphModule +from typing import Callable, List, Dict, Any, Optional +from torch.fx._compatibility import compatibility +from packaging import version +import inspect + + +@compatibility(is_backward_compatible=True) +class Partition: + """ + Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py + """ + + def __init__(self, name: str): + self.name: str = name + self.node_names: List[str] = [] + self.inputs: Dict[str, None] = {} + self.outputs: Dict[str, None] = {} + self.partitions_dependent_on: Dict[str, None] = {} + self.partition_dependents: Dict[str, None] = {} + self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph() + self.environment: Dict[torch.fx.node.Node, torch.fx.node.Node] = {} + self.targets: Dict[str, Any] = {} + + def __repr__(self) -> str: + return f"name: {self.name},\n" \ + f" nodes: {self.node_names},\n" \ + f" inputs: {self.inputs},\n" \ + f" outputs: {self.outputs},\n" \ + f" partitions depenent on: {self.partitions_dependent_on},\n" \ + f" parition dependents: {self.partition_dependents}" + + +# Creates subgraphs out of main graph +@compatibility(is_backward_compatible=True) +def split_module( + m: GraphModule, + root_m: torch.nn.Module, + split_callback: Callable[[torch.fx.node.Node], int], + merge_output = False, +): + """ + Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py + Creates subgraphs out of main graph + Args: + m (GraphModule): Graph module to split + root_m (torch.nn.Module): root nn module. Not currently used. Included + because the root nn module is usually transformed via + torch.fx._symbolic_trace.symbolic_trace (see example below) + split_callback (Callable[[torch.fx.node.Node], int]): Callable function + that maps a given Node instance to a numeric partition identifier. + split_module will use this function as the policy for which operations + appear in which partitions in the output Module. + Returns: + GraphModule: the module after split. + Example: + This is a sample setup: + import torch + from torch.fx.symbolic_trace import symbolic_trace + from torch.fx.graph_module import GraphModule + from torch.fx.node import Node + from colossalai.fx.passes.split_module import split_module + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + def forward(self, x, y): + z = self.linear(x + self.param).clamp(min=0.0, max=1.0) + w = self.linear(y).clamp(min=0.0, max=1.0) + return z + w + # symbolically trace model + my_module = MyModule() + my_module_traced = symbolic_trace(my_module) + # random mod partitioning + partition_counter = 0 + NPARTITIONS = 3 + def mod_partition(node: Node): + global partition_counter + partition = partition_counter % NPARTITIONS + partition_counter = (partition_counter + 1) % NPARTITIONS + return partition + # split module in module with submodules + module_with_submodules = split_module( + my_module_traced, my_module, mod_partition + ) + Output looks like this. Original graph is broken into partitions + > print(module_with_submodules) + GraphModule( + (submod_0): GraphModule( + (linear): Linear(in_features=4, out_features=5, bias=True) + ) + (submod_1): GraphModule( + (linear): Linear(in_features=4, out_features=5, bias=True) + ) + (submod_2): GraphModule() + ) + def forward(self, x, y): + param = self.param + submod_0 = self.submod_0(x, param, y); x = param = y = None + getitem = submod_0[0] + getitem_1 = submod_0[1]; submod_0 = None + submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None + getitem_2 = submod_1[0] + getitem_3 = submod_1[1]; submod_1 = None + submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None + return submod_2 + Output of split module is the same as output of input traced module. + This is an example within a test setting: + > orig_out = my_module_traced(x, y) + > submodules_out = module_with_submodules(x, y) + > self.assertEqual(orig_out, submodules_out) + True + """ + partitions: Dict[str, Partition] = {} + orig_nodes: Dict[str, torch.fx.node.Node] = {} + + def record_cross_partition_use(def_node: torch.fx.node.Node, + use_node: Optional[torch.fx.node.Node]): # noqa: B950 + def_partition_name = getattr(def_node, '_fx_partition', None) + use_partition_name = getattr(use_node, '_fx_partition', None) + if def_partition_name != use_partition_name: + if def_partition_name is not None: + def_partition = partitions[def_partition_name] + def_partition.outputs.setdefault(def_node.name) + if use_partition_name is not None: + def_partition.partition_dependents.setdefault(use_partition_name) + + if use_partition_name is not None: + use_partition = partitions[use_partition_name] + use_partition.inputs.setdefault(def_node.name) + if def_partition_name is not None: + use_partition.partitions_dependent_on.setdefault(def_partition_name) + + def record_output( + def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node] + ): # noqa: B950 + def_partition_name = getattr(def_node, "_fx_partition", None) + use_partition_name = getattr(use_node, "_fx_partition", None) + if def_partition_name != use_partition_name: + if def_partition_name is not None: + def_partition = partitions[def_partition_name] + def_partition.outputs.setdefault(def_node.name) + if use_partition_name is not None: + def_partition.partition_dependents.setdefault(use_partition_name) + + if use_partition_name is not None: + use_partition = partitions[use_partition_name] + use_partition.inputs.setdefault(def_node.name) + if def_partition_name is not None: + use_partition.partitions_dependent_on.setdefault(def_partition_name) + use_partition.outputs.setdefault(def_node.name) + else: + if use_partition_name is not None: + use_partition = partitions[use_partition_name] + use_partition.outputs.setdefault(def_node.name) + + # split nodes into parititons + for node in m.graph.nodes: + orig_nodes[node.name] = node + + if node.op in ["placeholder"]: + continue + if node.op == 'output': + if merge_output: + torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev)) + else: + torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None)) + continue + partition_name = str(split_callback(node)) + + # add node to partitions + partition = partitions.get(partition_name) + if partition is None: + partitions[partition_name] = partition = Partition(partition_name) + + partition.node_names.append(node.name) + node._fx_partition = partition_name + + torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node)) + torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 + + # find partitions with no dependencies + root_partitions: List[str] = [] + for partition_name, partition in partitions.items(): + if not len(partition.partitions_dependent_on): + root_partitions.append(partition_name) + + # check partitions for circular dependencies and create topological partition ordering + sorted_partitions: List[str] = [] + while root_partitions: + root_partition = root_partitions.pop() + sorted_partitions.append(root_partition) + for dependent in partitions[root_partition].partition_dependents: + partitions[dependent].partitions_dependent_on.pop(root_partition) + if not partitions[dependent].partitions_dependent_on: + root_partitions.append(dependent) + if len(sorted_partitions) != len(partitions): + raise RuntimeError("cycle exists between partitions!") + + # add placeholders to parititons + for partition_name in sorted_partitions: + partition = partitions[partition_name] + for input in partition.inputs: + placeholder = partition.graph.placeholder(input) + placeholder.meta = orig_nodes[input].meta.copy() + partition.environment[orig_nodes[input]] = placeholder + + # Transform nodes and collect targets for partition's submodule + for node in m.graph.nodes: + if hasattr(node, '_fx_partition'): + partition = partitions[node._fx_partition] + + # swap out old graph nodes in kw/args with references to new nodes in this submodule + environment = partition.environment + gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) + gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n]) + + if node.op not in ['call_module', 'get_attr']: + target = node.target + else: + target_atoms = node.target.split('.') + target_attr = m + for atom in target_atoms: + if not hasattr(target_attr, atom): + raise RuntimeError(f'Operator target {node.target} not found!') + target_attr = getattr(target_attr, atom) + # target = target_atoms[-1] + target = '_'.join(target_atoms) + partition.targets[target] = target_attr + + assert isinstance(gathered_args, tuple) + assert isinstance(gathered_kwargs, dict) + new_node = partition.graph.create_node(op=node.op, + target=target, + args=gathered_args, + kwargs=gathered_kwargs) + new_node.meta = node.meta.copy() + partition.environment[node] = new_node + + # Set up values to construct base module + base_mod_env: Dict[str, torch.fx.node.Node] = {} + base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() + base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} + for node in m.graph.nodes: + if node.op == 'placeholder': + if version.parse(torch.__version__) < version.parse('1.11.0'): + base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type) + else: + default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty + base_mod_env[node.name] = base_mod_graph.placeholder(node.target, + type_expr=node.type, + default_value=default_value) + base_mod_env[node.name].meta = node.meta.copy() + + # Do some things iterating over the partitions in topological order again: + # 1) Finish off submodule Graphs by setting corresponding outputs + # 2) Construct GraphModules for each submodule + # 3) Construct the base graph by emitting calls to those submodules in + # topological order + + for partition_name in sorted_partitions: + partition = partitions[partition_name] + + # Set correct output values + output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs) + output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] + partition.graph.output(output_vals) + + # Construct GraphModule for this partition + submod_name = f'submod_{partition_name}' + base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, + partition.graph) # noqa: B950 + + # Emit call in base graph to this submodule + output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs)) + if len(partition.outputs) > 1: + # Unpack multiple return values from submodule + output_val_proxy = torch.fx.proxy.Proxy(output_val) + for i, output_name in enumerate(partition.outputs): + base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] + else: + if not partition.outputs: + continue + base_mod_env[list(partition.outputs)[0]] = output_val + + for node in m.graph.nodes: + if node.op == 'output': + base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 + + for partition_name in sorted_partitions: + partition = partitions[partition_name] + + new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) + + return new_gm diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bb4f3cd6a4908177ca13f9d7fb82ff42b5ad1d5e --- /dev/null +++ b/colossalai/fx/passes/utils.py @@ -0,0 +1,172 @@ +import torch +from typing import Dict +from torch.fx.node import Node, map_arg +from torch.fx.graph import Graph + +def get_comm_size(prev_partition, next_partition): + """ + Given two partitions (parent and child), + calculate the communication size between the two. + """ + # Keep tracking the communication size between parent and child + comm_size = 0 + # Keep tracking all the counted node + visited_nodes = set() + # Go through all nodes in the child partition + # If a node has input nodes from the parent partition, + # the output size of those input nodes will be counted + # and added to comm_size + parent_node_names = [n.name for n in prev_partition.graph.nodes] + for node in next_partition.graph.nodes: + input_nodes: Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + for n in input_nodes: + if n.name in parent_node_names and n not in visited_nodes: + comm_size += n.meta['tensor_meta'].numel + visited_nodes.add(n) + return comm_size + + +def get_leaf(graph: Graph): + """ + Given a graph, return leaf nodes of this graph. + Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph, + we will get a normal DAG. Leaf nodes in this context means leaf nodes in that DAG. + """ + input_nodes: Dict[Node, None] = {} + for node in graph.nodes: + if node.op == 'output': + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + placeholder_nodes = [] + for node in input_nodes.keys(): + if node.op == 'placeholder': + placeholder_nodes.append(node) + for node in placeholder_nodes: + input_nodes.pop(node) + return list(input_nodes.keys()) + + +def is_leaf(graph: Graph, node: Node): + return node in get_leaf(graph) + + +def get_top(graph: Graph): + """ + Given a graph, return top nodes of this graph. + Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph, + we will get a normal DAG. Top nodes in this context means nodes with BFS level 0 in that DAG. + """ + top_node_list = set() + for node in graph.nodes: + if node.op == 'output': + continue + is_top = False + + def _get_top(node): + nonlocal is_top + if node.op == 'placeholder': + is_top = True + + map_arg(node.args, lambda n: _get_top(n)) + map_arg(node.kwargs, lambda n: _get_top(n)) + if is_top: + top_node_list.add(node) + return list(top_node_list) + + +def is_top(graph: Graph, node: Node): + return node in get_top(graph) + + +def get_all_consumers(graph: Graph, node: Node): + """ + Given a graph and a node of this graph, return all consumers of the node. + + Returns: + List of ``Nodes`` that node appear in these nodes ``args`` and ``kwargs``. + """ + consumer_list = [] + for n in graph.nodes: + if node in n.all_input_nodes: + consumer_list.append(n) + return consumer_list + + +def assign_bfs_level_to_nodes(graph: Graph): + """ + Give a graph, assign bfs level to each node of this graph excluding ``placeholder`` and ``output`` nodes. + Example: + class MLP(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim) + self.linear2 = torch.nn.Linear(dim, dim) + self.linear3 = torch.nn.Linear(dim, dim) + self.linear4 = torch.nn.Linear(dim, dim) + self.linear5 = torch.nn.Linear(dim, dim) + def forward(self, x): + l1 = self.linear1(x) + l2 = self.linear2(x) + l3 = self.linear3(l1) + l4 = self.linear4(l2) + l5 = self.linear5(l3) + return l4, l5 + model = MLP(4) + gm = symbolic_trace(model) + print(gm.graph) + assign_bfs_level_to_nodes(gm.graph) + for node in gm.graph.nodes: + if hasattr(node, 'bfs_level'): + print(node.name, node.bfs_level) + + Output: + graph(): + %x : [#users=2] = placeholder[target=x] + %linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {}) + %linear2 : [#users=1] = call_module[target=linear2](args = (%x,), kwargs = {}) + %linear3 : [#users=1] = call_module[target=linear3](args = (%linear1,), kwargs = {}) + %linear4 : [#users=1] = call_module[target=linear4](args = (%linear2,), kwargs = {}) + %linear5 : [#users=1] = call_module[target=linear5](args = (%linear3,), kwargs = {}) + return (linear4, linear5) + linear1 0 + linear2 0 + linear3 1 + linear4 1 + linear5 2 + """ + current_level = 0 + nodes_to_process = [] + + top_nodes = get_top(graph) + for node in top_nodes: + node.bfs_level = current_level + nodes_to_process.extend(get_all_consumers(graph, node)) + + current_level += 1 + while nodes_to_process: + new_process_list = [] + for node in nodes_to_process: + if node.op == 'output': + continue + node.bfs_level = current_level + new_process_list.extend(get_all_consumers(graph, node)) + nodes_to_process = new_process_list + current_level += 1 + + +def get_node_module(node) -> torch.nn.Module: + """ + Find the module associated with the given node. + Args: + node (torch.fx.Node): a torch.fx.Node object in the fx computation graph + Returns: + torch.nn.Module: the module associated with the given node + """ + + assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object' + assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}' + module = node.graph.owning_module.get_submodule(node.target) + return module + diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bcbde0eb23b806b7e37e407d57a962d8ff71573 --- /dev/null +++ b/colossalai/fx/profiler/__init__.py @@ -0,0 +1,18 @@ +from .._compatibility import is_compatible_with_meta + +if is_compatible_with_meta(): + from .opcount import flop_mapping + from .profiler import profile_function, profile_method, profile_module + from .shard_utils import ( + calculate_bwd_time, + calculate_fwd_in, + calculate_fwd_out, + calculate_fwd_time, + calculate_fwd_tmp, + ) + from .tensor import MetaTensor +else: + from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out + +from .dataflow import GraphInfo +from .memory_utils import activation_size, is_inplace, parameter_size diff --git a/colossalai/fx/profiler/constants.py b/colossalai/fx/profiler/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..5763a46dc83f19dadefbebb32dbcf9a59578a2b3 --- /dev/null +++ b/colossalai/fx/profiler/constants.py @@ -0,0 +1,44 @@ +import torch + +__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN', 'RELU_LIKE_OPS', 'RELU_LIKE_MOD'] + +aten = torch.ops.aten + +ALIAS_ATEN = [ + aten.detach.default, + aten.t.default, + aten.transpose.int, + aten.view.default, + aten._unsafe_view.default, + aten._reshape_alias.default, +] + +INPLACE_NEW = [ + aten.empty_like.default, + aten.new_empty_strided.default, +] + +INPLACE_MATH_ATEN = [ + aten.add_.Tensor, + aten.sub_.Tensor, + aten.div_.Tensor, + aten.div_.Scalar, + aten.mul_.Tensor, + aten.bernoulli_.float, +] + +CLONE_ATEN = [ + aten.clone.default, +] + +# See illustrations in +# https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/fx/profiler/constants.py +OUTPUT_SAVED_OPS = [ + torch.nn.functional.relu, + torch.nn.functional.softmax, +] + +OUTPUT_SAVED_MOD = [ + torch.nn.ReLU, + torch.nn.Softmax, +] diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e8880322b84f54e6c4742a821f4de76dfb664a --- /dev/null +++ b/colossalai/fx/profiler/dataflow.py @@ -0,0 +1,141 @@ +from dataclasses import dataclass, field +from enum import Enum +from functools import partial +from typing import Dict, List + +from torch.fx import Graph, Node + +from .._compatibility import compatibility +from .memory_utils import activation_size, is_inplace + + +class Phase(Enum): + FORWARD = 0 + BACKWARD = 1 + PLACEHOLDER = 2 + + +@compatibility(is_backward_compatible=True) +@dataclass +class GraphInfo: + """ + GraphInfo is a dataclass for MetaInfo, which measures + the execution memory cost and FLOPs with `MetaTensor`. + The dataflow analysis is conducted on a single node of the FX graph. + ============================================================================ + ------------------------------- + | Node | + [fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] is marks the memory for `grad_out`. + placeholders saved for | | \__________ | | + backward. | | \ | | + | [fwd_tmp] ------> [bwd_tmp] | <----- + | | \_________ | | [bwd_tmp] marks the peak memory + | / \ \ | | in backward pass. + [x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <----- + in [fwd_tmp] because | | \_____ | | + it is not saved for | | \ | | + backward. | [fwd_out] \ | | <----- [fwd_out] is [fwd_in] for the next node. + ------------------------------- + ============================================================================ + Attributes: + fwd_flop (int): The forward FLOPs of a certain node. + fwd_time (float): The real forward time (s) of a certain node. + bwd_flop (int): The backward FLOPs of a certain node. + bwd_time (float): The real backward time (s) of a certain node. + save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes. + fwd_in (List): See the above illustration. + fwd_tmp (List): See the above illustration. + fwd_out (List): See the above illustration. + fwd_mem_tmp (int): See the above illustration. + fwd_mem_out (int): See the above illustration. + bwd_mem_tmp (int): See the above illustration. + bwd_mem_out (int): See the above illustration. + """ + + # TODO(super-dainiu): removed redundant items, currently all of them are necessary for development + + fwd_flop: int = 0 + fwd_time: float = 0.0 + bwd_flop: int = 0 + bwd_time: float = 0.0 + save_fwd_in: bool = False + fwd_in: List = field(default_factory=list) + fwd_tmp: List = field(default_factory=list) + fwd_out: List = field(default_factory=list) + fwd_mem_tmp: int = 0 + fwd_mem_out: int = 0 + bwd_mem_tmp: int = 0 + bwd_mem_out: int = 0 + + +def is_phase(n: Node, phase: Phase) -> bool: + assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!' + return n.meta['phase'] == phase + + +@compatibility(is_backward_compatible=False) +def autograd_graph_analysis(graph: Graph) -> GraphInfo: + """Analyze the autograd node dependencies and find out the memory usage. + Basically the input graph should have all nodes marked for keyword `phase`. + Nodes should have attribute `out` indicating the output of each node. + ============================================================================ + Placeholder ----> p o <---- We need to keep track of grad out + |\________ | + โ†“ โ†˜| + f --------> b + |\ \_____ โ†‘ + | \ โ†˜ / + f f ----> b <---- Not every forward result needs to be saved for backward + | \____ โ†‘ + โ†˜ โ†˜| + f ----> b <---- Backward can be freed as soon as it is required no more. + โ†˜ โ†— + l + ============================================================================= + Args: + graph (Graph): The autograd graph with nodes marked for keyword `phase`. + + Returns: + graph_info (GraphInfo): Meta information for the dataflow. + """ + + def _peak_memory(deps: Dict[Node, int]): + peak_mem = 0 + for k, v in deps.items(): + if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k): + peak_mem += activation_size(k.meta['saved_tensor']) + if v <= float('-inf') and is_phase(k, Phase.FORWARD): + peak_mem -= activation_size(k.meta['saved_tensor']) + return peak_mem + + # deps is used to track all the memory dependencies of the graph. + deps = {} + graph_info = GraphInfo() + + for n in graph.nodes: + n: Node + deps[n] = len(n.users) + # A forward tensor who is marked `save` but is also + # an input to `Phase.FORWARD` should be saved during forward. + # If the tensor is a placeholder, then it belongs to `fwd_mem_in`. + # Any `fwd_mem_in` should be kept in memory even this function + # is checkpointed. + # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint + # the node, `fwd_mem_tmp` can be freed. + if is_phase(n, Phase.PLACEHOLDER): + graph_info.fwd_in += n.meta['saved_tensor'] + if is_phase(n, Phase.FORWARD): + graph_info.fwd_tmp += n.meta['saved_tensor'] + elif is_phase(n, Phase.BACKWARD): + if len(n.users): + graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps)) + else: + # TODO: some of the bwd_mem_out might be model parameters. + # basically a backward node without user is a `grad_out` node + graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor']) + for input_n in n.all_input_nodes: + if input_n in deps: + deps[input_n] -= 1 + if deps[input_n] <= 0: + deps[input_n] = float('-inf') + return graph_info diff --git a/colossalai/fx/profiler/experimental/__init__.py b/colossalai/fx/profiler/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a5387981e1921497d27c2acee9ebf4e310c3add7 --- /dev/null +++ b/colossalai/fx/profiler/experimental/__init__.py @@ -0,0 +1,5 @@ +from .profiler import profile_function, profile_method, profile_module +from .profiler_function import * +from .profiler_module import * +from .registry import meta_profiler_function, meta_profiler_module +from .shard_utils import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp diff --git a/colossalai/fx/profiler/experimental/constants.py b/colossalai/fx/profiler/experimental/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..57ff3fd91299b5bb8938125bf2d3243c9a9c4c2b --- /dev/null +++ b/colossalai/fx/profiler/experimental/constants.py @@ -0,0 +1,44 @@ +from operator import add, floordiv, getitem, mul, neg, pos, setitem, sub + +import torch + +__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD'] + +# TODO fill out the inplace ops +INPLACE_OPS = [ + add, + sub, + mul, + floordiv, + neg, + pos, + getitem, + setitem, + getattr, + torch.Tensor.cpu, +] + +# TODO: list all call_methods that are inplace here +INPLACE_METHOD = [ + 'transpose', + 'permute', + # TODO: reshape may return a copy of the data if the data is not contiguous + 'reshape', + 'dim', + 'flatten', + 'size', + 'view', + 'unsqueeze', + 'to', + 'type', + 'flatten', +] + +# TODO: list all call_methods that are not inplace here +NON_INPLACE_METHOD = [ + 'chunk', + 'contiguous', + 'expand', + 'mean', + 'split', +] diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..5c545260e72b723bfa54beacfb20def3e758413f --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -0,0 +1,172 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, Tuple + +import torch +from torch.fx.node import Argument, Target + +from ..._compatibility import compatibility +from ..memory_utils import activation_size +from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD +from .registry import meta_profiler_function, meta_profiler_module + +__all__ = ['profile_function', 'profile_module', 'profile_method'] + + +# this is for compatibility use +@compatibility(is_backward_compatible=True) +@dataclass +class GraphInfo: + """ + GraphInfo is a dataclass for MetaInfo, which measures + the execution memory cost and FLOPs with `MetaTensor`. + The dataflow analysis is conducted on a single node of the FX graph. + ============================================================================ + ------------------------------- + | Node | + [fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] is marks the memory for `grad_out` + placeholders saved for | | \__________ | | + backward. | | \ | | + | [fwd_tmp] ------> [bwd_tmp] | <----- + | | \_________ | | [bwd_tmp] marks the peak memory + | / \ \ | | in backward pass. + [x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <----- + in [fwd_tmp] because | | | \_____ | | + it is not saved for | | | \ | | + backward. ------------------------------- + ============================================================================ + Attributes: + fwd_flop (int): The forward FLOPs of a certain node + bwd_flop (int): The backward FLOPs of a certain node. + fwd_mem_in (int): See the above illustration. + fwd_mem_tmp (int): See the above illustration. + bwd_mem_tmp (int): See the above illustration. + bwd_mem_out (int): See the above illustration. + """ + fwd_flop: int = 0 + bwd_flop: int = 0 + fwd_mem_in: int = 0 + fwd_mem_tmp: int = 0 + bwd_mem_tmp: int = 0 + bwd_mem_out: int = 0 + + +CALL_FUNCTION_MSG = \ +""" +Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n +from colossalai.fx.profiler.experimental import meta_profiler_function +@meta_profiler_function.register(YOUR_FUNCTION) +def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]: + flops = ... + macs = ... + return flops, macs +""" +CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}' +CALL_MODULE_MSG = \ +""" +Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n +from colossalai.fx.profiler.experimental import meta_profiler_module +@meta_profiler_module.register(YOUR_MODULE) +def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]: + flops = ... + macs = ... + return flops, macs +""" + + +@compatibility(is_backward_compatible=True) +def profile_function(target: 'Target') -> Callable: + """ + Wrap a `call_function` node or `torch.nn.functional` in order to + record the memory cost and FLOPs of the execution. + Unfortunately, backward memory cost and FLOPs are estimated results. + + Warnings: + You may only use tensors with `device=meta` for this wrapped function. + Only original `torch.nn.functional` are available. + + Examples: + >>> input = torch.rand(100, 100, 100, 100, device='meta') + >>> func = torch.nn.functional.relu + >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False) + """ + + def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + assert meta_profiler_function.has(target) or meta_profiler_function.has( + target.__name__), CALL_FUNCTION_MSG.format(target) + + fwd_tmp = 0 + fwd_out = 0 + out = func(*args, **kwargs) + if target not in INPLACE_OPS and not kwargs.get('inplace', False): + fwd_out = activation_size(out) + if meta_profiler_function.has(target): + profiler = meta_profiler_function.get(target) + else: + profiler = meta_profiler_function.get(target.__name__) + fwd_flop, _ = profiler(*args, **kwargs) + return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + + f.__name__ = target.__name__ + func = target + return f + + +@compatibility(is_backward_compatible=True) +def profile_method(target: 'Target') -> Callable: + """ + Wrap a `call_method` node + record the memory cost and FLOPs of the execution. + + Warnings: + This is not fully implemented and you may follow the error message to debug. + """ + + def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + + # execute the method and return the result + assert isinstance(target, str), f'{target} instance is not str.' + + out = getattr(self_obj, target)(*args_tail, **kwargs) + assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format( + target, INPLACE_METHOD, NON_INPLACE_METHOD) + # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs. + fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out) + fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out) + return out, GraphInfo(0, 0, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + + return f + + +@compatibility(is_backward_compatible=True) +def profile_module(module: torch.nn.Module) -> Callable: + """ + Wrap a `call_module` node or `torch.nn` in order to + record the memory cost and FLOPs of the execution. + + Warnings: + You may only use tensors with `device=meta` for this wrapped function. + Only original `torch.nn` are available. + + Example: + >>> input = torch.rand(4, 3, 224, 224, device='meta') + >>> mod = torch.nn.Conv2d(3, 128, 3) + >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input) + """ + + def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module)) + + fwd_tmp = 0 + fwd_out = 0 + out = func(*args, **kwargs) + if getattr(module, 'inplace', False): + fwd_out = activation_size(out) + profiler = meta_profiler_module.get(type(module)) + fwd_flop, _ = profiler(module, *args, **kwargs) + return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + + f.__name__ = module.__class__.__name__ + func = module.forward + return f diff --git a/colossalai/fx/profiler/experimental/profiler_function/__init__.py b/colossalai/fx/profiler/experimental/profiler_function/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf77edba859ecc568a5010287b8797fc31bb6701 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/__init__.py @@ -0,0 +1,8 @@ +from .activation_function import * +from .arithmetic import * +from .embedding import * +from .linear import * +from .normalization import * +from .pooling import * +from .python_ops import * +from .torch_ops import * diff --git a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py new file mode 100644 index 0000000000000000000000000000000000000000..a43aef063e197de12c23fc5a81fb13e8183eaae9 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py @@ -0,0 +1,33 @@ +from typing import Tuple +import torch +from ..registry import meta_profiler_function + +# TODO: different activation has different FLOPs count, currently unused. +_multiplier = { + torch.nn.functional.relu: 1, + torch.nn.functional.prelu: 4, + torch.nn.functional.sigmoid: 4, + torch.nn.functional.tanh: 5, + torch.nn.functional.leaky_relu: 3, + torch.nn.functional.elu: 4, + torch.nn.functional.relu6: 2, + torch.nn.functional.gelu: 9, + torch.nn.functional.hardswish: 5, + torch.nn.functional.hardsigmoid: 4, +} + + +@meta_profiler_function.register(torch.nn.functional.leaky_relu) +@meta_profiler_function.register(torch.nn.functional.elu) +@meta_profiler_function.register(torch.nn.functional.gelu) +@meta_profiler_function.register(torch.nn.functional.relu6) +@meta_profiler_function.register(torch.nn.functional.prelu) +@meta_profiler_function.register(torch.nn.functional.relu) +@meta_profiler_function.register(torch.nn.functional.sigmoid) +@meta_profiler_function.register(torch.nn.functional.tanh) +@meta_profiler_function.register(torch.nn.functional.hardswish) +@meta_profiler_function.register(torch.nn.functional.hardsigmoid) +def torch_nn_func_non_linear_act(input: torch.Tensor, inplace: bool = False) -> Tuple[int, int]: + flops = input.numel() + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py new file mode 100644 index 0000000000000000000000000000000000000000..2cf50133d3bd5965192bb34a07fbafc71e6c7903 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py @@ -0,0 +1,85 @@ +import operator +from functools import reduce +from typing import Any, Optional, Tuple, Union +import torch +from ..registry import meta_profiler_function + + +def _elementwise_flops_compute(input, other): + # copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L763 + if not torch.is_tensor(input): + if torch.is_tensor(other): + return reduce(operator.mul, other.shape), 0 + else: + return 1, 0 + elif not torch.is_tensor(other): + return reduce(operator.mul, input.shape), 0 + else: + dim_input = len(input.shape) + dim_other = len(other.shape) + max_dim = max(dim_input, dim_other) + + final_shape = [] + for i in range(max_dim): + in_i = input.shape[i] if i < dim_input else 1 + ot_i = other.shape[i] if i < dim_other else 1 + if in_i > ot_i: + final_shape.append(in_i) + else: + final_shape.append(ot_i) + flops = reduce(operator.mul, final_shape) + return flops, 0 + + +@meta_profiler_function.register(torch.add) +@meta_profiler_function.register(torch.eq) +@meta_profiler_function.register(torch.sub) +@meta_profiler_function.register(torch.mul) +@meta_profiler_function.register(torch.floor_divide) +@meta_profiler_function.register('add') # for built-in op + +@meta_profiler_function.register('iadd') # for built-in op += +@meta_profiler_function.register('eq') # for built-in op = +@meta_profiler_function.register('sub') # for built-in op - +@meta_profiler_function.register('isub') # for built-in op -= +@meta_profiler_function.register('mul') # for built-in op * +@meta_profiler_function.register('imul') # for built-in op *= +@meta_profiler_function.register('floordiv') # for built-in op // +@meta_profiler_function.register('ifloordiv') # for built-in op //= +def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + return _elementwise_flops_compute(input, other) + + +@meta_profiler_function.register(torch.abs) +def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + flops = input.numel() + macs = 0 + return flops, macs + + +@meta_profiler_function.register(torch.matmul) +@meta_profiler_function.register('matmul') # for built-in op @ +@meta_profiler_function.register(torch.Tensor.matmul) +def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + macs = reduce(operator.mul, input.shape) * other.shape[-1] + flops = 2 * macs + return flops, macs + + +@meta_profiler_function.register(torch.bmm) +def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + macs = reduce(operator.mul, input.shape) * other.shape[-1] + flops = 2 * macs + return flops, macs + + +@meta_profiler_function.register(torch.var_mean) +def torch_var_mean(input: torch.Tensor, + dim: Union[int, Tuple[int, ...]], + unbiased: Optional[bool] = True, + keepdim: Optional[bool] = False, + *, + out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + assert out is None, 'saving to out is not supported yet' + flops = input.numel() * 3 + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/embedding.py b/colossalai/fx/profiler/experimental/profiler_function/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..d6e43d781b8b64ab78cf3299daba3df1d17a5420 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/embedding.py @@ -0,0 +1,19 @@ +import torch +from typing import Optional +from ..registry import meta_profiler_function + + +@meta_profiler_function.register(torch.nn.functional.embedding) +def torch_nn_functional_embedding( + input: torch.Tensor, + weight: torch.Tensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, +) -> torch.Tensor: + # F.embedding is a dictionary lookup, so technically it has 0 FLOPs. (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6) + flops = 0 + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/linear.py b/colossalai/fx/profiler/experimental/profiler_function/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..01fe4c87137083db2458c560a88cc6faa0af377e --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/linear.py @@ -0,0 +1,13 @@ +from typing import Tuple +import torch +from ..registry import meta_profiler_function + + +@meta_profiler_function.register(torch.nn.functional.linear) +def torch_nn_linear(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> Tuple[int, int]: + out_features = weight.shape[0] + macs = torch.numel(input) * out_features + flops = 2 * macs + if bias is not None: + flops += bias.numel() + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/normalization.py b/colossalai/fx/profiler/experimental/profiler_function/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..c4ea508d70f80f33bbc8ae354e9743a4939d5e8c --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/normalization.py @@ -0,0 +1,66 @@ +from typing import List, Optional, Tuple +import torch +from ..registry import meta_profiler_function + + +@meta_profiler_function.register(torch.nn.functional.instance_norm) +def torch_nn_func_instancenorm( + input: torch.Tensor, + running_mean: Optional[torch.Tensor] = None, + running_var: Optional[torch.Tensor] = None, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + use_input_stats: bool = True, + momentum: float = 0.1, + eps: float = 1e-5, +): + has_affine = weight is not None + flops = input.numel() * (5 if has_affine else 4) + macs = 0 + return flops, macs + + +@meta_profiler_function.register(torch.nn.functional.group_norm) +def torch_nn_func_groupnorm(input: torch.Tensor, + num_groups: int, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + eps: float = 1e-5) -> Tuple[int, int]: + has_affine = weight is not None + flops = input.numel() * (5 if has_affine else 4) + macs = 0 + return flops, macs + + +@meta_profiler_function.register(torch.nn.functional.layer_norm) +def torch_nn_func_layernorm( + input: torch.Tensor, + normalized_shape: List[int], + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + eps: float = 1e-5, +) -> Tuple[int, int]: + has_affine = weight is not None + flops = input.numel() * (5 if has_affine else 4) + macs = 0 + return flops, macs + + +@meta_profiler_function.register(torch.nn.functional.batch_norm) +def torch_nn_func_batchnorm( + input: torch.Tensor, + running_mean: Optional[torch.Tensor], + running_var: Optional[torch.Tensor], + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + training: bool = False, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tuple[int, int]: + has_affine = weight is not None + if training: + flops = input.numel() * (2 if has_affine else 1) + else: + flops = input.numel() * (5 if has_affine else 4) + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/pooling.py b/colossalai/fx/profiler/experimental/profiler_function/pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..a639f5ee83c1f4d2b75a3c120ea6ae3884fc422f --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/pooling.py @@ -0,0 +1,22 @@ +from typing import Tuple, Union +import torch +from ..registry import meta_profiler_function + + +@meta_profiler_function.register(torch.nn.functional.avg_pool1d) +@meta_profiler_function.register(torch.nn.functional.avg_pool2d) +@meta_profiler_function.register(torch.nn.functional.avg_pool3d) +@meta_profiler_function.register(torch.nn.functional.max_pool1d) +@meta_profiler_function.register(torch.nn.functional.max_pool2d) +@meta_profiler_function.register(torch.nn.functional.max_pool3d) +@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool1d) +@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool2d) +@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool3d) +@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool1d) +@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool2d) +@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool3d) +def torch_nn_func_pooling(input: torch.Tensor, *args, **kwargs) -> Tuple[int, int]: + # all pooling could be considered as going over each input element only once (https://stackoverflow.com/a/67301217) + flops = input.numel() + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..1e8561206ba0e7a874b202a31dd13a040533d1db --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py @@ -0,0 +1,18 @@ +import operator +from typing import Any, Tuple +import torch +from ..registry import meta_profiler_function + + +@meta_profiler_function.register(operator.getitem) +def operator_getitem(a: Any, b: Any) -> Tuple[int, int]: + flops = 0 + macs = 0 + return flops, macs + + +@meta_profiler_function.register(getattr) +def python_getattr(a: Any, b: Any) -> Tuple[int, int]: + flops = 0 + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..abdd7ad565ba237d7d6eab9e3c9b77d7afb10abf --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py @@ -0,0 +1,60 @@ +from functools import reduce +import operator +from typing import Any, Optional, Tuple +import torch +from ..registry import meta_profiler_function + + +@meta_profiler_function.register(torch.arange) +@meta_profiler_function.register(torch.finfo) +@meta_profiler_function.register(torch.permute) +@meta_profiler_function.register(torch.Tensor.permute) +@meta_profiler_function.register(torch.Tensor.repeat) +@meta_profiler_function.register(torch.index_select) +@meta_profiler_function.register(torch.Tensor.index_select) +@meta_profiler_function.register(torch.squeeze) +@meta_profiler_function.register(torch.Tensor.squeeze) +@meta_profiler_function.register(torch.unsqueeze) +@meta_profiler_function.register(torch.Tensor.unsqueeze) +@meta_profiler_function.register(torch.cat) +@meta_profiler_function.register(torch.concat) +@meta_profiler_function.register(torch.repeat_interleave) +@meta_profiler_function.register(torch.Tensor.repeat_interleave) +@meta_profiler_function.register(torch.flatten) +@meta_profiler_function.register(torch.Tensor.flatten) +@meta_profiler_function.register(torch.roll) +@meta_profiler_function.register(torch.full) +@meta_profiler_function.register(torch.Tensor.cpu) +@meta_profiler_function.register(torch.Tensor.cuda) +@meta_profiler_function.register(torch._assert) +def torch_zero_flops_op(*args, **kwargs) -> Tuple[int, int]: + flops = 0 + macs = 0 + return flops, macs + + +@meta_profiler_function.register(torch.where) +def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]: + # torch.where returns the broadcasted tensor of condition, x, and y, + # so hack it by using addition + flops = condition.numel() + macs = 0 + return flops, macs + + +@meta_profiler_function.register(torch.max) +def torch_max(input: torch.Tensor, + dim: int = None, + keepdim: bool = False, + *, + out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + macs = 0 + assert out is None, 'assigning value to out is not supported yet' + if dim is not None: + shape = list(input.shape) + shape.pop(int(dim)) + flops = reduce(operator.mul, shape), macs + return flops, macs + else: + flops = input.numel() + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/__init__.py b/colossalai/fx/profiler/experimental/profiler_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e4fe646f3695e63d285b9a32530cd70f10187f34 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/__init__.py @@ -0,0 +1,10 @@ +from .activation_function import * +from .attention import * +from .convolution import * +from .dropout import * +from .embedding import * +from .linear import * +from .normalization import * +from .pooling import * +from .rnn import * +from .torch_op import * diff --git a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py new file mode 100644 index 0000000000000000000000000000000000000000..2ebf514ad2699cc4e71741b9c3e143cedcb63041 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py @@ -0,0 +1,33 @@ +from typing import Tuple +import torch +from ..registry import meta_profiler_module + +# TODO: different activation has different FLOPs count, currently unused. +_multiplier = { + torch.nn.ReLU: 1, + torch.nn.PReLU: 4, + torch.nn.Sigmoid: 4, + torch.nn.Tanh: 5, + torch.nn.LeakyReLU: 3, + torch.nn.ELU: 4, + torch.nn.ReLU6: 2, + torch.nn.GELU: 9, + torch.nn.Hardswish: 5, + torch.nn.Hardsigmoid: 4, +} + + +@meta_profiler_module.register(torch.nn.ELU) +@meta_profiler_module.register(torch.nn.LeakyReLU) +@meta_profiler_module.register(torch.nn.ReLU) +@meta_profiler_module.register(torch.nn.GELU) +@meta_profiler_module.register(torch.nn.Sigmoid) +@meta_profiler_module.register(torch.nn.Tanh) +@meta_profiler_module.register(torch.nn.ReLU6) +@meta_profiler_module.register(torch.nn.PReLU) +@meta_profiler_module.register(torch.nn.Hardswish) +@meta_profiler_module.register(torch.nn.Hardsigmoid) +def torch_nn_non_linear_act(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]: + flops = input.numel() + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/attention.py b/colossalai/fx/profiler/experimental/profiler_module/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..8daf74b232bf91d41933a2184e7b0c30d516d51a --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/attention.py @@ -0,0 +1,81 @@ +from typing import Optional, Tuple +import torch +from ..registry import meta_profiler_module + + +# TODO: This is hard to compute memory cost +@meta_profiler_module.register(torch.nn.MultiheadAttention) +def torch_nn_msa(self: torch.nn.MultiheadAttention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_padding_mask: Optional[torch.Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[torch.Tensor] = None, + average_attn_weights: bool = True) -> Tuple[int, int]: + if getattr(self, 'batch_first', False): + batch_size = query.shape[0] + len_idx = 1 + else: + batch_size = query.shape[1] + len_idx = 0 + dim_idx = 2 + + qdim = query.shape[dim_idx] + kdim = key.shape[dim_idx] + vdim = value.shape[dim_idx] + + qlen = query.shape[len_idx] + klen = key.shape[len_idx] + vlen = value.shape[len_idx] + + num_heads = self.num_heads + assert qdim == self.embed_dim + + if self.kdim is None: + assert kdim == qdim + if self.vdim is None: + assert vdim == qdim + + flops = 0 + macs = 0 + + # Q scaling + flops += qlen * qdim + + # Initial projections + flops += 2 * ((qlen * qdim * qdim) # QW + + (klen * kdim * kdim) # KW + + (vlen * vdim * vdim) # VW + ) + + macs += ((qlen * qdim * qdim) # QW + + (klen * kdim * kdim) # KW + + (vlen * vdim * vdim) # VW + ) + + if self.in_proj_bias is not None: + flops += (qlen + klen + vlen) * qdim + + # attention heads: scale, matmul, softmax, matmul + qk_head_dim = qdim // num_heads + v_head_dim = vdim // num_heads + + head_flops = ( + 2 * (qlen * klen * qk_head_dim) # QK^T + + (qlen * klen) # softmax + + 2 * (qlen * klen * v_head_dim) # AV + ) + head_macs = ((qlen * klen * qk_head_dim) # QK^T + + 2 * (qlen * klen * v_head_dim) # AV + ) + + flops += num_heads * head_flops + macs += num_heads * head_flops + + # final projection, bias is always enabled + flops += qlen * vdim * (vdim + 1) + + flops *= batch_size + macs *= batch_size + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/convolution.py b/colossalai/fx/profiler/experimental/profiler_module/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..3193489fee5ec2f3751aae0738deb70927b6b401 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/convolution.py @@ -0,0 +1,152 @@ +import operator +from functools import reduce +import math +from typing import Tuple +import torch +from ..registry import meta_profiler_module + + +@meta_profiler_module.register(torch.nn.Conv1d) +def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, int]: + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + c_in, l_in = input.shape[-2:] + c_out = self.out_channels + l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + result_shape = input.shape[:-2] + ( + c_out, + l_out, + ) + macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups + num_elem = reduce(operator.mul, result_shape) + macs = macs_per_elem * num_elem + flops = 2 * macs + if self.bias is not None: + flops += num_elem + return flops, macs + + +@meta_profiler_module.register(torch.nn.Conv2d) +def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, int]: + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + c_in, h_in, w_in = input.shape[-3:] + c_out = self.out_channels + h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * + (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + result_shape = input.shape[:-3] + ( + c_out, + h_out, + w_out, + ) + macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups + num_elem = reduce(operator.mul, result_shape) + macs = macs_per_elem * num_elem + flops = 2 * macs + if self.bias is not None: + flops += num_elem + return flops, macs + + +@meta_profiler_module.register(torch.nn.Conv3d) +def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, int]: + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html + c_in, d_in, h_in, w_in = input.shape[-4:] + c_out = self.out_channels + d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * + (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * + (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1) + result_shape = input.shape[:-4] + ( + c_out, + d_out, + h_out, + w_out, + ) + macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups + num_elem = reduce(operator.mul, result_shape) + macs = macs_per_elem * num_elem + flops = 2 * macs + if self.bias is not None: + flops += num_elem + return flops, macs + + +@meta_profiler_module.register(torch.nn.ConvTranspose1d) +def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor) -> Tuple[int, int]: + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html + c_in, l_in = input.shape[-2:] + c_out = self.out_channels + l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * + (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + result_shape = input.shape[:-2] + ( + c_out, + l_out, + ) + macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups + num_elem = reduce( + operator.mul, input.shape + ) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604 + macs = macs_per_elem * num_elem + flops = 2 * macs + if self.bias is not None: + flops += reduce(operator.mul, result_shape) + return flops, macs + + +@meta_profiler_module.register(torch.nn.ConvTranspose2d) +def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor) -> Tuple[int, int]: + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html + c_in, h_in, w_in = input.shape[-3:] + c_out = self.out_channels + h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * + (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * + (self.kernel_size[1] - 1) + self.output_padding[1] + 1) + result_shape = input.shape[:-3] + ( + c_out, + h_out, + w_out, + ) + macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups + num_elem = reduce(operator.mul, input.shape) + macs = macs_per_elem * num_elem + flops = 2 * macs + if self.bias is not None: + flops += reduce(operator.mul, result_shape) + return flops, macs + + +@meta_profiler_module.register(torch.nn.ConvTranspose3d) +def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor) -> Tuple[int, int]: + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html + c_in, d_in, h_in, w_in = input.shape[-4:] + c_out = self.out_channels + d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * + (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * + (self.kernel_size[1] - 1) + self.output_padding[1] + 1) + w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] * + (self.kernel_size[2] - 1) + self.output_padding[2] + 1) + result_shape = input.shape[:-4] + ( + c_out, + d_out, + h_out, + w_out, + ) + macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups + num_elem = reduce(operator.mul, input.shape) + macs = macs_per_elem * num_elem + flops = 2 * macs + if self.bias is not None: + flops += reduce(operator.mul, result_shape) + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/dropout.py b/colossalai/fx/profiler/experimental/profiler_module/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..417e0ed468637a5ce049ffa8137a73e5b266c971 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/dropout.py @@ -0,0 +1,11 @@ +from typing import Tuple +import torch +from ..registry import meta_profiler_module + + +@meta_profiler_module.register(torch.nn.Dropout) +def torch_nn_dropout(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]: + # nn.Embedding is a dictionary lookup, so technically it has 0 FLOPs. (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6) + flops = 0 + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/embedding.py b/colossalai/fx/profiler/experimental/profiler_module/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..dca6f9453af3ca70195a5be84fc5ce665a3a32db --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/embedding.py @@ -0,0 +1,11 @@ +from typing import Tuple +import torch +from ..registry import meta_profiler_module + + +@meta_profiler_module.register(torch.nn.Embedding) +def torch_nn_embedding(self: torch.nn.Embedding, input: torch.Tensor) -> Tuple[int, int]: + # nn.Embedding is a dictionary lookup, so technically it has 0 FLOPs. (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6) + flops = 0 + macs = 0 + return flops, macs \ No newline at end of file diff --git a/colossalai/fx/profiler/experimental/profiler_module/linear.py b/colossalai/fx/profiler/experimental/profiler_module/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..e1ffb6f244d2ed7d5764339d61fdb46f71ae59a2 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/linear.py @@ -0,0 +1,14 @@ +from typing import Tuple +import torch +from ..registry import meta_profiler_module + + +@meta_profiler_module.register(torch.nn.Linear) +@meta_profiler_module.register(torch.nn.modules.linear.NonDynamicallyQuantizableLinear) +def torch_nn_linear(self: torch.nn.Linear, input: torch.Tensor) -> Tuple[int, int]: + out_features = self.weight.shape[0] + macs = input.numel() * out_features + flops = 2 * macs + if self.bias is not None: + flops += self.bias.numel() + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/normalization.py b/colossalai/fx/profiler/experimental/profiler_module/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..e9939da7b1c4172613867a70af76b9e27f3215f7 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/normalization.py @@ -0,0 +1,33 @@ +from typing import Tuple, Union +import torch +from ..registry import meta_profiler_module + + +@meta_profiler_module.register(torch.nn.InstanceNorm1d) +@meta_profiler_module.register(torch.nn.InstanceNorm2d) +@meta_profiler_module.register(torch.nn.InstanceNorm3d) +@meta_profiler_module.register(torch.nn.LayerNorm) +@meta_profiler_module.register(torch.nn.GroupNorm) +@meta_profiler_module.register(torch.nn.BatchNorm1d) +@meta_profiler_module.register(torch.nn.BatchNorm2d) +@meta_profiler_module.register(torch.nn.BatchNorm3d) +def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]: + # adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615 + has_affine = self.weight is not None + if self.training: + flops = input.numel() * (2 if has_affine else 1) + else: + flops = input.numel() * (5 if has_affine else 4) + macs = 0 + return flops, macs + + +try: + import apex + meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize) + meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize) + meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize) + meta_profiler_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize) +except (ImportError, AttributeError): + pass diff --git a/colossalai/fx/profiler/experimental/profiler_module/pooling.py b/colossalai/fx/profiler/experimental/profiler_module/pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..e429ac3eea28055f42af2ea8f84663a5a6fd2a83 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/pooling.py @@ -0,0 +1,22 @@ +from typing import Tuple +import torch +from ..registry import meta_profiler_module + + +@meta_profiler_module.register(torch.nn.AvgPool1d) +@meta_profiler_module.register(torch.nn.AvgPool2d) +@meta_profiler_module.register(torch.nn.AvgPool3d) +@meta_profiler_module.register(torch.nn.MaxPool1d) +@meta_profiler_module.register(torch.nn.MaxPool2d) +@meta_profiler_module.register(torch.nn.MaxPool3d) +@meta_profiler_module.register(torch.nn.AdaptiveAvgPool1d) +@meta_profiler_module.register(torch.nn.AdaptiveMaxPool1d) +@meta_profiler_module.register(torch.nn.AdaptiveAvgPool2d) +@meta_profiler_module.register(torch.nn.AdaptiveMaxPool2d) +@meta_profiler_module.register(torch.nn.AdaptiveAvgPool3d) +@meta_profiler_module.register(torch.nn.AdaptiveMaxPool3d) +def torch_nn_pooling(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]: + # all pooling could be considered as going over each input element only once (https://stackoverflow.com/a/67301217) + flops = input.numel() + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/rnn.py b/colossalai/fx/profiler/experimental/profiler_module/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..6e733d6da9156db13b2bac35af63a52ad89ad5a3 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/rnn.py @@ -0,0 +1,75 @@ +from functools import reduce +import operator +import torch +from ..registry import meta_profiler_module +from typing import Optional, Tuple, Union + + +def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, + w_hh: torch.Tensor) -> Tuple[int, int]: + # copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py + + # matrix matrix mult ih state and internal state + macs += reduce(operator.mul, w_ih.shape) + flops += 2 * reduce(operator.mul, w_ih.shape) + # matrix matrix mult hh state and internal state + macs += reduce(operator.mul, w_hh.shape) + flops += 2 * reduce(operator.mul, w_hh.shape) + if isinstance(module, (torch.nn.RNN, torch.nn.RNNCell)): + # add both operations + flops += module.hidden_size + elif isinstance(module, (torch.nn.GRU, torch.nn.GRUCell)): + # hadamard of r + flops += module.hidden_size + # adding operations from both states + flops += module.hidden_size * 3 + # last two hadamard product and add + flops += module.hidden_size * 3 + elif isinstance(module, (torch.nn.LSTM, torch.nn.LSTMCell)): + # adding operations from both states + flops += module.hidden_size * 4 + # two hadamard product and add for C state + flops += module.hidden_size * 3 + # final hadamard + flops += module.hidden_size * 3 + return flops, macs + + +@meta_profiler_module.register(torch.nn.LSTM) +@meta_profiler_module.register(torch.nn.GRU) +@meta_profiler_module.register(torch.nn.RNN) +def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]: + flops = 0 + macs = 0 + for i in range(self.num_layers): + w_ih = self.__getattr__('weight_ih_l' + str(i)) + w_hh = self.__getattr__('weight_hh_l' + str(i)) + flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh) + if self.bias: + b_ih = self.__getattr__('bias_ih_l' + str(i)) + b_hh = self.__getattr__('bias_hh_l' + str(i)) + flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh) + flops *= reduce(operator.mul, input.shape[:2]) + macs *= reduce(operator.mul, input.shape[:2]) + if self.bidirectional: + flops *= 2 + macs *= 2 + return flops, macs + + +@meta_profiler_module.register(torch.nn.LSTMCell) +@meta_profiler_module.register(torch.nn.GRUCell) +@meta_profiler_module.register(torch.nn.RNNCell) +def torch_nn_rnn(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]: + flops = 0 + macs = 0 + w_ih = self.__getattr__('weight_ih_l') + w_hh = self.__getattr__('weight_hh_l') + flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh) + if self.bias: + b_ih = self.__getattr__('bias_ih_l') + b_hh = self.__getattr__('bias_hh_l') + flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh) + flops *= input.shape[0] + macs *= input.shape[0] + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py new file mode 100644 index 0000000000000000000000000000000000000000..d3aed874eb10af76dc94e21b23c566178afe6264 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py @@ -0,0 +1,11 @@ +import operator +import torch +from ..registry import meta_profiler_module +from typing import Optional, Tuple, Union + + +@meta_profiler_module.register(torch.nn.Flatten) +def torch_nn_flatten(self: torch.nn.Flatten, input: torch.Tensor) -> Tuple[int, int]: + flops = 0 + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/registry.py b/colossalai/fx/profiler/experimental/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..7d73bce321e43d7c1284bf4d78dbac2bc7c4abfc --- /dev/null +++ b/colossalai/fx/profiler/experimental/registry.py @@ -0,0 +1,25 @@ +class ProfilerRegistry: + + def __init__(self, name): + self.name = name + self.store = {} + + def register(self, source): + + def wrapper(func): + self.store[source] = func + return func + + return wrapper + + def get(self, source): + assert source in self.store + target = self.store[source] + return target + + def has(self, source): + return source in self.store + + +meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile') +meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile') diff --git a/colossalai/fx/profiler/experimental/shard_utils.py b/colossalai/fx/profiler/experimental/shard_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1e53ed0bf8ec657d8916d49c8c4b97f1d996010a --- /dev/null +++ b/colossalai/fx/profiler/experimental/shard_utils.py @@ -0,0 +1,48 @@ +# for PyTorch 1.11 compatibility uses +from typing import Dict, List, Tuple, Union + +import torch +from torch.fx import GraphModule, Node + +from ..._compatibility import compatibility + +__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"] + + +@compatibility(is_backward_compatible=True) +def calculate_fwd_in(n: Node) -> bool: + """A helper function to calculate `fwd_in` + + Args: + n (Node): a node from the graph + + Returns: + save_fwd_in (bool): the result of `save_fwd_in` + """ + return n.meta['save_fwd_in'] + + +@compatibility(is_backward_compatible=True) +def calculate_fwd_tmp(n: Node) -> int: + """A helper function to calculate `fwd_tmp` + + Args: + n (Node): a node from the graph + + Returns: + fwd_tmp (int): the result of `fwd_tmp` + """ + return n.meta["fwd_mem_tmp"] + + +@compatibility(is_backward_compatible=True) +def calculate_fwd_out(n: Node) -> int: + """A helper function to calculate `fwd_out` + + Args: + n (Node): a node from the graph + + Returns: + fwd_out (int): the result of `fwd_out` + """ + return n.meta['fwd_mem_out'] diff --git a/colossalai/fx/profiler/memory_utils.py b/colossalai/fx/profiler/memory_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6ccbcb01cdc14045fbdd4906f1fc6f2a5ad728db --- /dev/null +++ b/colossalai/fx/profiler/memory_utils.py @@ -0,0 +1,71 @@ +from typing import Dict, List, Tuple, Union + +import torch +from torch.fx import GraphModule, Node + +from .._compatibility import compatibility, is_compatible_with_meta + +__all__ = ['activation_size', 'parameter_size', 'is_inplace'] + + +@compatibility(is_backward_compatible=True) +def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int: + """Calculate activation size of a node. + + Args: + activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`. + + Returns: + int: The activation size, unit is byte. + """ + act_size = 0 + if isinstance(out, torch.Tensor): + if out.is_quantized: + act_size += out.numel() * torch._empty_affine_quantized([], dtype=out.dtype).element_size() + else: + act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size() + elif isinstance(out, dict): + value_list = [v for _, v in out.items()] + act_size += activation_size(value_list) + elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set): + for element in out: + act_size += activation_size(element) + return act_size + + +@compatibility(is_backward_compatible=True) +def parameter_size(mod: torch.nn.Module) -> int: + """Calculate parameter size of a node. + + Args: + mod (torch.nn.Module): The target `torch.nn.Module`. + + Returns: + int: The parameter size, unit is byte. + """ + param_size = 0 + for param in mod.parameters(): + param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size() + return param_size + + +def is_inplace(n: Node): + """Get the inplace argument from torch.fx.Node + + Args: + node (Node): torch.fx.Node + + Returns: + bool: indicates whether this op is inplace + """ + inplace = False + if n.op == "call_function": + inplace = n.kwargs.get("inplace", False) + if is_compatible_with_meta(): + from .constants import ALIAS_ATEN + if n.target in ALIAS_ATEN: + inplace = True + elif n.op == "call_module": + inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False) + + return inplace diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py new file mode 100644 index 0000000000000000000000000000000000000000..bb8db54a478e3957025ea1d9d96f1f97c5003d4f --- /dev/null +++ b/colossalai/fx/profiler/opcount.py @@ -0,0 +1,318 @@ +# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py +# ideas from https://pastebin.com/AkvAyJBw + +import operator +from functools import partial, reduce +from numbers import Number +from typing import Any, Callable, List + +import torch + +aten = torch.ops.aten + + +def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for matmul. + """ + # Inputs should be a list of length 2. + # Inputs contains the shapes of two matrices. + input_shapes = [v.shape for v in inputs] + assert len(input_shapes) == 2, input_shapes + assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes + flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1] + return flops + + +def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for fully connected layers. + """ + # Count flop for nn.Linear + # inputs is a list of length 3. + input_shapes = [v.shape for v in inputs[1:3]] + # input_shapes[0]: [batch size, input feature dimension] + # input_shapes[1]: [input feature dimension, output feature dimension] + assert len(input_shapes[0]) == 2, input_shapes[0] + assert len(input_shapes[1]) == 2, input_shapes[1] + batch_size, input_dim = input_shapes[0] + output_dim = input_shapes[1][1] + flops = batch_size * input_dim * output_dim + return flops + + +def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for the aten::linear operator. + """ + # Inputs is a list of length 3; unlike aten::addmm, it is the first + # two elements that are relevant. + input_shapes = [v.shape for v in inputs[0:2]] + # input_shapes[0]: [dim0, dim1, ..., input_feature_dim] + # input_shapes[1]: [output_feature_dim, input_feature_dim] + assert input_shapes[0][-1] == input_shapes[1][-1] + flops = reduce(operator.mul, input_shapes[0]) * input_shapes[1][0] + return flops + + +def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for the bmm operation. + """ + # Inputs should be a list of length 2. + # Inputs contains the shapes of two tensor. + assert len(inputs) == 2, len(inputs) + input_shapes = [v.shape for v in inputs] + n, c, t = input_shapes[0] + d = input_shapes[-1][-1] + flops = n * c * t * d + return flops + + +def conv_flop_count( + x_shape: List[int], + w_shape: List[int], + out_shape: List[int], + transposed: bool = False, +) -> Number: + """ + Count flops for convolution. Note only multiplication is + counted. Computation for addition and bias is ignored. + Flops for a transposed convolution are calculated as + flops = (x_shape[2:] * prod(w_shape) * batch_size). + Args: + x_shape (list(int)): The input shape before convolution. + w_shape (list(int)): The filter shape. + out_shape (list(int)): The output shape after convolution. + transposed (bool): is the convolution transposed + Returns: + int: the number of flops + """ + batch_size = x_shape[0] + conv_shape = (x_shape if transposed else out_shape)[2:] + flops = batch_size * reduce(operator.mul, w_shape) * reduce(operator.mul, conv_shape) + return flops + + +def conv_flop_jit(inputs: List[Any], outputs: List[Any]): + """ + Count flops for convolution. + """ + x, w = inputs[:2] + x_shape, w_shape, out_shape = (x.shape, w.shape, outputs[0].shape) + transposed = inputs[6] + + return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed) + + +def transpose_shape(shape): + return [shape[1], shape[0]] + list(shape[2:]) + + +def conv_backward_flop_jit(inputs: List[Any], outputs: List[Any]): + grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]] + output_mask = inputs[-1] + fwd_transposed = inputs[7] + flop_count = 0 + + if output_mask[0]: + grad_input_shape = outputs[0].shape + flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed) + if output_mask[1]: + grad_weight_shape = outputs[1].shape + flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed) + + return flop_count + + +def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable: + """ + Args: + affine_arg_index: index of the affine argument in inputs + """ + + def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for norm layers. + """ + # Inputs[0] contains the shape of the input. + input_shape = inputs[input_arg_index].shape + + has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index], + 'shape') else inputs[affine_arg_index] + assert 2 <= len(input_shape) <= 5, input_shape + # 5 is just a rough estimate + flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4) + return flop + + return norm_flop_jit + + +def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = None) -> Number: + if training is None: + training = inputs[-3] + assert isinstance(training, bool), "Signature of aten::batch_norm has changed!" + if training: + return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore + has_affine = inputs[1].shape is not None + input_shape = reduce(operator.mul, inputs[0].shape) + return input_shape * (2 if has_affine else 1) + + +def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Callable: + """ + Count flops by + input_tensor.numel() * input_scale + output_tensor.numel() * output_scale + Args: + input_scale: scale of the input tensor (first argument) + output_scale: scale of the output tensor (first element in outputs) + """ + + def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number: + ret = 0 + if input_scale != 0: + shape = inputs[0].shape + ret += input_scale * reduce(operator.mul, shape) if shape else 0 + if output_scale != 0: + shape = outputs[0].shape + ret += output_scale * reduce(operator.mul, shape) if shape else 0 + return ret + + return elementwise_flop + + +def zero_flop_jit(*args): + """ + Count flops for zero flop layers. + """ + return 0 + + +flop_mapping = { + # gemm + aten.mm.default: matmul_flop_jit, + aten.matmul.default: matmul_flop_jit, + aten.addmm.default: addmm_flop_jit, + aten.bmm.default: bmm_flop_jit, + + # convolution + aten.convolution.default: conv_flop_jit, + aten._convolution.default: conv_flop_jit, + aten.convolution_backward.default: conv_backward_flop_jit, + + # normalization + aten.native_batch_norm.default: batchnorm_flop_jit, + aten.native_batch_norm_backward.default: batchnorm_flop_jit, + aten.cudnn_batch_norm.default: batchnorm_flop_jit, + aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True), + aten.native_layer_norm.default: norm_flop_counter(2, 0), + aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), + + # pooling + aten.avg_pool1d.default: elementwise_flop_counter(1, 0), + aten.avg_pool2d.default: elementwise_flop_counter(1, 0), + aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1), + aten.avg_pool3d.default: elementwise_flop_counter(1, 0), + aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1), + aten.max_pool1d.default: elementwise_flop_counter(1, 0), + aten.max_pool2d.default: elementwise_flop_counter(1, 0), + aten.max_pool3d.default: elementwise_flop_counter(1, 0), + aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0), + aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0), + aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1), + aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0), + aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1), + aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0), + aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1), + aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0), + aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1), + aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1), + aten.embedding.default: elementwise_flop_counter(1, 0), +} + +elementwise_flop_aten = [ + # basic op + aten.add.Tensor, + aten.add_.Tensor, + aten.div.Tensor, + aten.div_.Tensor, + aten.div.Scalar, + aten.div_.Scalar, + aten.mul.Tensor, + aten.mul.Scalar, + aten.mul_.Tensor, + aten.neg.default, + aten.pow.Tensor_Scalar, + aten.rsub.Scalar, + aten.sum.default, + aten.sum.dim_IntList, + aten.mean.dim, + + # activation op + aten.hardswish.default, + aten.hardswish_.default, + aten.hardswish_backward.default, + aten.hardtanh.default, + aten.hardtanh_.default, + aten.hardtanh_backward.default, + aten.hardsigmoid_backward.default, + aten.hardsigmoid.default, + aten.gelu.default, + aten.gelu_backward.default, + aten.silu.default, + aten.silu_.default, + aten.silu_backward.default, + aten.sigmoid.default, + aten.sigmoid_backward.default, + aten._softmax.default, + aten._softmax_backward_data.default, + aten.relu_.default, + aten.relu.default, + aten.tanh.default, + aten.tanh_backward.default, + aten.threshold_backward.default, + + # dropout + aten.native_dropout.default, + aten.native_dropout_backward.default, +] + +for op in elementwise_flop_aten: + flop_mapping[op] = elementwise_flop_counter(1, 0) + +# TODO: this will be removed in future +zero_flop_aten = [ + aten.as_strided.default, + aten.as_strided_.default, + aten.bernoulli_.float, + aten.cat.default, + aten.clone.default, + aten.copy_.default, + aten.detach.default, + aten.expand.default, + aten.empty_like.default, + aten.new_empty.default, + aten.new_empty_strided.default, + aten.ones_like.default, + aten._reshape_alias.default, + aten.select.int, + aten.select_backward.default, + aten.squeeze.dim, + aten.slice.Tensor, + aten.slice_backward.default, + aten.split.Tensor, + aten.permute.default, + aten.t.default, + aten.transpose.int, + aten._to_copy.default, + aten.unsqueeze.default, + aten.unbind.int, + aten._unsafe_view.default, + aten.view.default, + aten.where.self, + aten.zero_.default, + aten.zeros_like.default, +] + +for op in zero_flop_aten: + flop_mapping[op] = zero_flop_jit diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..c87cd4321d31c59ebe369a2903b679a439fb4f96 --- /dev/null +++ b/colossalai/fx/profiler/profiler.py @@ -0,0 +1,409 @@ +import time +from functools import partial +from typing import Any, Callable, Dict, Tuple + +import torch +from torch.fx import Graph, Node +from torch.fx.node import Argument, Target +from torch.nn.parameter import Parameter +from torch.utils._pytree import tree_map + +from .._compatibility import compatibility +from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS +from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase +from .memory_utils import activation_size, parameter_size +from .opcount import flop_mapping +from .tensor import MetaTensor + +__all__ = ['profile_function', 'profile_module', 'profile_method'] + +# super-dainiu: this cache should be global, otherwise it cannot +# track duplicated tensors between nodes +cache = set() + +# a global identifier for inplace ops +do_not_cache = False + + +def normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +def is_autogradable(x): + return isinstance(x, torch.Tensor) and x.is_floating_point() + + +def detach_variables(x): + if isinstance(x, torch.Tensor): + requires_grad = x.requires_grad + x = x.detach() + x.requires_grad = requires_grad + + return x + + +@compatibility(is_backward_compatible=True) +def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]: + """Profile a Callable function with args and kwargs on concrete devices by https://github.com/Cypher30 + To profile the actual forward memory, we first run target in the context torch.no_grad() to get + the fwd_mem_out, then we run target with grad enable to found the extra memory stored in the memory + by memory allocated minus the fwd_mem_out. + To profile the actual backward memory, we first make dummy gradient for torch.autograd.backward, then + find the bwd_mem_tmp with memory peak during the process minus bwd_mem_out(it is actually equal to size + of args and kwargs). + We also add time stamps to profile the real forward and backward time. + + Args: + target (Callable): A Callable function + args (Any): Arguments + kwargs (Any): Arguments + + Returns: + Tuple[Tuple[Any, ...], GraphInfo]: Output for next node & memory cost and real forward and backward + time. + """ + + graphinfo = GraphInfo() + + # detach input from the graph + args = tree_map(detach_variables, args) + kwargs = tree_map(detach_variables, kwargs) + if isinstance(target, str): + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + + # calculate fwd_mem_out + mem_stamp0 = torch.cuda.memory_allocated() + with torch.no_grad(): + out = getattr(self_obj, target)(*args_tail, **kwargs) + mem_stamp1 = torch.cuda.memory_allocated() + graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0 + del out + + # calculate fwd_mem_tmp & fwd_time + mem_stamp0 = torch.cuda.memory_allocated() + fwd_time0 = time.time() + out = getattr(self_obj, target)(*args_tail, **kwargs) + fwd_time1 = time.time() + graphinfo.fwd_time = fwd_time1 - fwd_time0 + mem_stamp1 = torch.cuda.memory_allocated() + graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out + + # calculate bwd_mem_tmp & bwd_time + grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + bwd_time0 = time.time() + torch.autograd.backward(out, grad_tensors=grad_tensors) + bwd_time1 = time.time() + graphinfo.bwd_time = bwd_time1 - bwd_time0 + mem_stamp1 = torch.cuda.max_memory_allocated() + + # calculate bwd memory stats + # NOTE: the module should add param to bwd_mem_out for bwd_mem_tmp calculation + graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs) + graphinfo.bwd_mem_out += parameter_size(target.__self__) if hasattr(target.__self__, "parameters") else 0 + graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out + + else: + # calculate fwd_mem_out + mem_stamp0 = torch.cuda.memory_allocated() + with torch.no_grad(): + out = target(*args, **kwargs) + mem_stamp1 = torch.cuda.memory_allocated() + graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0 + del out + + # calculate fwd_mem_tmp & fwd_time + mem_stamp0 = torch.cuda.memory_allocated() + fwd_time0 = time.time() + out = target(*args, **kwargs) + fwd_time1 = time.time() + graphinfo.fwd_time = fwd_time1 - fwd_time0 + mem_stamp1 = torch.cuda.memory_allocated() + graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out + + # calculate bwd_mem_tmp & bwd_time + grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + bwd_time0 = time.time() + torch.autograd.backward(out, grad_tensors=grad_tensors) + bwd_time1 = time.time() + graphinfo.bwd_time = bwd_time1 - bwd_time0 + mem_stamp1 = torch.cuda.max_memory_allocated() + + # calculate bwd memory stats + # NOTE: the module should add param to bwd_mem_out for bwd_mem_tmp calculation + graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs) + graphinfo.bwd_mem_out += parameter_size(target.__self__) if hasattr(target.__self__, "parameters") else 0 + graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out + + return tree_map(detach_variables, out), graphinfo + + +@compatibility(is_backward_compatible=False) +def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]: + """ + Profile a Callable function with args and kwargs on meta devices. + + Args: + target (Callable): A Callable function + args (Any): Argument + kwargs (Any): Argument + + Returns: + out (Tuple[Any, ...]): The argument value that was retrieved. + meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + # This subgraph traces aten level ops inside one node. + subgraph = Graph() + + # `flop_count`` serves as a global dictionary to store results. + flop_count = { + Phase.FORWARD: 0, + Phase.BACKWARD: 0, + } + + # FlopTensor not only get the flop statistics of a single node, + # it also build a full autograd graph for this node. + # This makes sure we can analyze the dependencies of memory, and + # decide which forward intermediate results should be kept until + # backward is executed. + # Hopefully, this attempt will provide a better estimation of memory. + class FlopTensor(MetaTensor): + + _node: Node = None + + def __repr__(self): + if self.grad_fn: + return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, grad_fn={self.grad_fn})" + return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, requires_grad={self.requires_grad})" + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args) + kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs) + node = subgraph.create_node('call_function', func, args_node, kwargs_node) + + out = super().__torch_dispatch__(func, types, args, kwargs) + + flop_count[phase] += flop_mapping[func](args, normalize_tuple(out)) + node.meta['phase'] = phase + + # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs, + # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during + # `Phase.FORWARD` + if phase == Phase.FORWARD: + if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN: + node.meta['phase'] = Phase.PLACEHOLDER + + # TODO(yby): specify `saved_tensors` for backward memory estimation + node.meta['saved_tensor'] = [] + if phase == Phase.BACKWARD: + node.meta['saved_tensor'] = normalize_tuple(out) + + def wrap(x): + if isinstance(x, MetaTensor): + x = FlopTensor(x) + x._node = node + return x + + out = tree_map(wrap, out) + return out + + def wrap(x): + if isinstance(x, torch.Tensor): + x = FlopTensor(x) + if is_autogradable(x): + x.requires_grad_(True) + x._node = subgraph.create_node('placeholder', + 'placeholder', (subgraph._root,), + name=subgraph._graph_namespace.create_name('input', x._tensor)) + x._node.meta['phase'] = Phase.PLACEHOLDER + x._node.meta['saved_tensor'] = [] + return x + + # Basically, we need to detach the args and kwargs from the outer graph. + args = tree_map(wrap, args) + kwargs = tree_map(wrap, kwargs) + + def pack(x): + global cache, do_not_cache + if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache: + tensor = x._tensor.detach() + tensor.data_ptr = x._tensor.data_ptr + x._node.meta['saved_tensor'] += [tensor] + if not do_not_cache: + cache.add(x._tensor.data_ptr()) + return x + + def unpack(x): + return x + + # `phase` will mark the phase of autograd from outside scope. + phase = Phase.FORWARD + # mark saved tensors with saved_tensors_hooks + with torch.autograd.graph.saved_tensors_hooks(pack, unpack): + if isinstance(target, str): + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + out = getattr(self_obj, target)(*args_tail, **kwargs) + else: + out = target(*args, **kwargs) + + # If the output is not a floating point `torch.Tensor` or it does not + # requires grad, then we should not run backward for this node. + if all(map(lambda x: is_autogradable(x) and x.requires_grad, normalize_tuple(out))): + grad_out = [torch.zeros_like(t) for t in normalize_tuple(out)] + phase = Phase.BACKWARD + torch.autograd.backward( + out, + grad_out, + ) + + graph_info = autograd_graph_analysis(subgraph) + graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD] + + def extract_tensor(x: Any): + if isinstance(x, MetaTensor): + tensor = x._tensor.detach() + tensor.data_ptr = x._tensor.data_ptr + return tensor + if not isinstance(x, torch.finfo): + return x + + graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out))) + + def unwrap(x): + return MetaTensor(x) if isinstance(x, torch.Tensor) else x + + return tree_map(unwrap, out), graph_info + + +@compatibility(is_backward_compatible=True) +def profile_function(target: 'Target', device: str = 'meta') -> Callable: + """ + Wrap a `call_function` node or `torch.nn.functional` in order to + record the memory cost and FLOPs of the execution. + + Warnings: + You may only use tensors with `device=meta` for this wrapped function. + Only original `torch.nn.functional` are available. + + Examples: + >>> input = torch.rand(100, 100, 100, 100, device='meta') + >>> func = torch.nn.functional.relu + >>> output, meta_info = profile_function(func)(input) + """ + + def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + + # find the grad for parameter in args and kwargs + param_size = 0 + + def get_param_size(x): + nonlocal param_size + if isinstance(x, Parameter): + param_size += activation_size(x) + + tree_map(get_param_size, args) + tree_map(get_param_size, kwargs) + + # If there is an argument that this `call_function` is inplace, we should + # still run the profiling but discard some results regarding `target` + global do_not_cache + + inplace = kwargs.get('inplace', False) + if target in OUTPUT_SAVED_OPS: + do_not_cache = True + if inplace: + do_not_cache = True + kwargs['inplace'] = False + if device == 'meta': + out, meta = _profile_meta(func, *args, **kwargs) + else: + out, meta = _profile_concrete(func, *args, **kwargs) + if inplace: + kwargs['inplace'] = True + meta.bwd_mem_tmp = 0 + meta.bwd_mem_out = 0 + do_not_cache = False + + meta.bwd_mem_out -= param_size + return out, meta + + f.__name__ = target.__name__ + func = target + return f + + +@compatibility(is_backward_compatible=True) +def profile_method(target: 'Target', device: str = 'meta') -> Callable: + """ + Wrap a `call_method` node + record the memory cost and FLOPs of the execution. + """ + + def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + # execute the method and return the result + assert isinstance(target, str), f'{target} instance is not str.' + if device == 'meta': + out, meta = _profile_meta(target, *args, **kwargs) + else: + out, meta = _profile_concrete(target, *args, **kwargs) + return out, meta + + return f + + +@compatibility(is_backward_compatible=True) +def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: + """ + Wrap a `call_module` node or `torch.nn` in order to + record the memory cost and FLOPs of the execution. + + Warnings: + You may only use tensors with `device=meta` for this wrapped function. + Only original `torch.nn` are available. + + Example: + >>> input = torch.rand(4, 3, 224, 224, device='meta') + >>> mod = torch.nn.Conv2d(3, 128, 3) + >>> output, meta_info = profile_module(mod)(input) + """ + + def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + + # calculate parameter size + param_size = parameter_size(module) + + # If there is an argument that this `call_module` is inplace, we should + # still run the profiling but discard some results regarding `module`. + global do_not_cache + + inplace = getattr(module, 'inplace', False) + if type(module) in OUTPUT_SAVED_MOD: + do_not_cache = True + if inplace: + do_not_cache = True + module.inplace = False + if device == 'meta': + out, meta = _profile_meta(func, *args, **kwargs) + else: + out, meta = _profile_concrete(func, *args, **kwargs) + if inplace: + module.inplace = True + meta.bwd_mem_tmp = 0 + meta.bwd_mem_out = 0 + do_not_cache = False + + # grad for param will not be counted + meta.bwd_mem_out -= param_size + return out, meta + + f.__name__ = module.__class__.__name__ + func = module.forward + return f diff --git a/colossalai/fx/profiler/shard_utils.py b/colossalai/fx/profiler/shard_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a765e5055b28746c123fd0d268ce91874d6f0fd7 --- /dev/null +++ b/colossalai/fx/profiler/shard_utils.py @@ -0,0 +1,114 @@ +import torch +from torch.fx import Node + +from .._compatibility import compatibility, is_compatible_with_meta +from .memory_utils import activation_size + +if is_compatible_with_meta(): + from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS + +__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"] + + +@compatibility(is_backward_compatible=False) +def calculate_fwd_in(n: Node) -> int: + """A helper function to calculate `fwd_in` (with sharding spec) + + Args: + n (Node): a node from the graph + + Returns: + fwd_in (int): the result of `fwd_in` + """ + # TODO(super-dainiu): should divide the memory by sharding spec + return activation_size(n.meta["fwd_in"]) + + +@compatibility(is_backward_compatible=False) +def calculate_fwd_tmp(n: Node) -> int: + """A helper function to calculate `fwd_tmp` (with sharding spec) + Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy. + + Args: + n (Node): a node from the graph + + Returns: + fwd_tmp (int): the result of `fwd_tmp` + """ + + # TODO(super-dainiu): should divide the memory by sharding spec + def is_relu_like_node(n: Node) -> bool: + """Check if a node is a ReLU-like node. + ReLU-like nodes have the following properties: + - They are either `call_function` or `call_module` + - Their output tensors are directly saved for backward + - Their input tensors are not saved for backward + + An example is `torch.nn.functional.softmax` which has (forward + backward): + def forward(self, input_2): + _softmax_default = torch.ops.aten._softmax.default(input_2, None, None); input_2 = None + zeros_like_default = torch.ops.aten.zeros_like.default(_softmax_default, dtype = None, layout = None, device = None, pin_memory = None) + detach_default = torch.ops.aten.detach.default(_softmax_default); _softmax_default = None + _softmax_backward_data_default = torch.ops.aten._softmax_backward_data.default(zeros_like_default, detach_default, None, None); zeros_like_default = detach_default = None + detach_default_1 = torch.ops.aten.detach.default(_softmax_backward_data_default); _softmax_backward_data_default = None + detach_default_2 = torch.ops.aten.detach.default(detach_default_1); detach_default_1 = None + + Args: + n (Node): A node from the graph + + Returns: + bool: Whether the node is a ReLU-like node + """ + if n.op == 'call_function': + return n.target in OUTPUT_SAVED_OPS + elif n.op == 'call_module': + return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD + return False + + if not is_relu_like_node(n): + return activation_size(n.meta["fwd_tmp"]) + return 0 + + +@compatibility(is_backward_compatible=False) +def calculate_fwd_out(n: Node) -> int: + """A helper function to calculate `fwd_out` (with sharding spec) + + Args: + n (Node): a node from the graph + + Returns: + fwd_out (int): the result of `fwd_out` + """ + + # TODO(super-dainiu): should divide the memory by sharding spec + def intersect(a, b): + return {k: a[k] for k in a if k in b} + + fwd_in = dict() + for u in n.users: + fwd_in.update({x.data_ptr(): x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor)}) + fwd_out = {x.data_ptr(): x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor)} + return activation_size(intersect(fwd_in, fwd_out)) + + +def calculate_fwd_time(n: Node) -> float: + """A helper function to calculate `fwd_time` (with sharding spec) + Args: + n (Node): a node from the graph + Returns: + fwd_time (float): the result of `fwd_time` + """ + # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs + return n.meta["fwd_flop"] + + +def calculate_bwd_time(n: Node) -> float: + """A helper function to calculate `bwd_time` (with sharding spec) + Args: + n (Node): a node from the graph + Returns: + bwd_time (float): the result of `bwd_time` + """ + # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs + return n.meta["bwd_flop"] diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..43165305f010cf61c5e0538c8ce95106c0894e5a --- /dev/null +++ b/colossalai/fx/profiler/tensor.py @@ -0,0 +1,140 @@ +import uuid +from copy import deepcopy +from typing import Optional + +import torch +from torch.types import _bool, _device, _dtype +from torch.utils._pytree import tree_flatten, tree_map + +from .._compatibility import compatibility +from .constants import ALIAS_ATEN + +__all__ = ['MetaTensor'] + + +def set_data_ptr(x): + if isinstance(x, torch.Tensor): + if not x.data_ptr(): + data_ptr = uuid.uuid4() + x.data_ptr = lambda: data_ptr + + +@compatibility(is_backward_compatible=False) +class MetaTensor(torch.Tensor): + """ + A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops. + `fake_device` is the device that `MetaTensor` is supposed to run on. + """ + + _tensor: torch.Tensor + + __slots__ = ['_tensor'] + + @staticmethod + def __new__(cls, elem, fake_device=None): + # Avoid multiple wrapping + if isinstance(elem, MetaTensor): + fake_device = elem.device if fake_device is None else fake_device + elem = elem._tensor + + # The wrapping tensor (MetaTensor) shouldn't hold any + # memory for the class in question, but it should still + # advertise the same device as before + r = torch.Tensor._make_wrapper_subclass( + cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), + dtype=elem.dtype, + layout=elem.layout, + device=fake_device if fake_device is not None else elem.device, + requires_grad=elem.requires_grad) # deceive the frontend for aten selections + r._tensor = elem + # ...the real tensor is held as an element on the tensor. + if not r._tensor.is_meta: + r._tensor = r._tensor.to(torch.device('meta')) + # only tensor not on `meta` should be copied to `meta` + set_data_ptr(r._tensor) + return r + + def __repr__(self): + if self.grad_fn: + return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})" + return f"MetaTensor({self._tensor}, fake_device='{self.device}')" + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + fake_device = None + + def unwrap(x): + nonlocal fake_device + if isinstance(x, MetaTensor): + fake_device = x.device + x = x._tensor + elif isinstance(x, torch.Tensor): + fake_device = x.device + x = x.to(torch.device('meta')) + return x + + if 'device' in kwargs: + fake_device = kwargs['device'] + kwargs['device'] = torch.device('meta') + + args = tree_map(unwrap, args) + kwargs = tree_map(unwrap, kwargs) + + # run aten for backend=CPU but actually on backend=Meta + out = func(*args, **kwargs) + + # here we keep the uuid of input because ALIAS_ATEN do not generate a physical copy + # of the input + if func in ALIAS_ATEN: + out.data_ptr = args[0].data_ptr + + # Now, we want to continue propagating this tensor, so we rewrap Tensors in + # our custom tensor subclass + def wrap(x): + if isinstance(x, torch.Tensor): + nonlocal fake_device + if not x.is_meta: + x = x.to(torch.device('meta')) + return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x + + return tree_map(wrap, out) + + def to(self, *args, **kwargs) -> torch.Tensor: + """An extension of `torch.Tensor.to()` to MetaTensor + + Returns: + result (MetaTensor): MetaTensor + + Usage: + >>> tensor = MetaTensor(torch.rand(10), fake_device='cuda:100') + >>> tensor.to(torch.uint8) + MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), fake_device='cuda:100') + >>> tensor.to(torch.device('cuda:42')) + MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='cuda:42') + >>> tensor.to('vulkan') + MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan') + """ + # this imitates c++ function in the way of @overload + device = None + for arg in args: + if isinstance(arg, str) or isinstance(arg, _device): + device = arg + if 'device' in kwargs: + device = kwargs['device'] + result = super().to(*args, **kwargs) + if device is not None: + result = MetaTensor(result, fake_device=device) + return result + + def cpu(self, *args, **kwargs): + if self.device.type == 'cpu': + return self.to(*args, **kwargs) + return self.to(*args, device='cpu', **kwargs) + + def cuda(self, *args, **kwargs): + if self.device.type == 'cuda': + return self.to(*args, **kwargs) + return self.to(*args, device='cuda', **kwargs) diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..06272c48f852c8fd14900b0ff832688f300d18a8 --- /dev/null +++ b/colossalai/fx/proxy.py @@ -0,0 +1,127 @@ +import operator +import torch +from torch.fx.proxy import Proxy, Attribute +from typing import List, Union, Any +from colossalai.fx.tracer.meta_patch import meta_patched_function + +__all__ = ['ColoProxy'] + + +class ColoProxy(Proxy): + """ + ColoProxy is a proxy class which uses meta tensor to handle data-dependent control flow. The original torch.fx proxy + cannot be used to infer the condition statement, with this proxy, torch.fx can still run even with if statements. + + Example:: + + proxy = tracer.create_proxy(...) + proxy.meta_data = torch.empty(4, 2, device='meta') + print(len(proxy)) # expect output 4 + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.node._meta_data = None + + @property + def meta_data(self): + return self.node._meta_data + + @meta_data.setter + def meta_data(self, data: Any): + self.node._meta_data = data + + @property + def has_meta_data(self): + return self._meta_data is not None + + def _assert_meta_data_is_tensor(self): + assert torch.is_tensor( + self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}' + + def _assert_has_meta_data(self): + assert self._meta_data is not None, f'Meta data is not set for {self.node.name}' + + def __len__(self): + self._assert_has_meta_data() + return len(self.meta_data) + + def __int__(self): + self._assert_has_meta_data() + return int(self.meta_data) + + def __float__(self): + self._assert_has_meta_data() + return float(self.meta_data) + + def __bool__(self): + self._assert_has_meta_data() + return self.meta_data + + def __getattr__(self, k): + + return ColoAttribute(self, k) + + def __contains__(self, key): + if self.node.op == "placeholder": + # this is used to handle like + # if x in kwargs + # we don't handle this case for now + return False + return super().__contains__(key) + + +def extract_meta(*args, **kwargs): + """ + This function is copied from _tracer_utils.py to avoid circular import issue. + """ + + def _convert(val): + if isinstance(val, ColoProxy): + return val.meta_data + elif isinstance(val, (list, tuple)): + return type(val)([_convert(ele) for ele in val]) + return val + + new_args = [_convert(val) for val in args] + new_kwargs = {k: _convert(v) for k, v in kwargs.items()} + return new_args, new_kwargs + + +class ColoAttribute(ColoProxy): + + def __init__(self, root, attr: str): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._node = None + + @property + def node(self): + if self._node is None: + proxy = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}) + if not isinstance(proxy, ColoProxy): + meta_args, meta_kwargs = extract_meta(*(self.root, self.attr)) + meta_out = getattr(*meta_args, **meta_kwargs) + proxy = ColoProxy(proxy.node) + proxy.meta_data = meta_out + self._node = proxy.node + + return self._node + + def __call__(self, *args, **kwargs): + proxy = self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs) + if not isinstance(proxy, ColoProxy): + meta_args, meta_kwargs = extract_meta(*((self.root,) + args), **kwargs) + method = getattr(meta_args[0].__class__, self.attr) + if meta_patched_function.has(method): + meta_target = meta_patched_function.get(method) + elif meta_patched_function.has(method.__name__): + meta_target = meta_patched_function.get(method.__name__) + else: + meta_target = method + meta_out = meta_target(*meta_args, **meta_kwargs) + proxy = ColoProxy(proxy.node) + proxy.meta_data = meta_out + return proxy diff --git a/colossalai/fx/tracer/__init__.py b/colossalai/fx/tracer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..590555ce38bf30fab0c575bb69f05af74c8d3802 --- /dev/null +++ b/colossalai/fx/tracer/__init__.py @@ -0,0 +1,5 @@ +from colossalai.fx.tracer.meta_patch.patched_function.python_ops import operator_getitem + +from ._meta_trace import meta_trace +from ._symbolic_trace import symbolic_trace +from .tracer import ColoTracer diff --git a/colossalai/fx/tracer/_meta_trace.py b/colossalai/fx/tracer/_meta_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..1c5abb81d271144ab666049d3ae2868cd9568497 --- /dev/null +++ b/colossalai/fx/tracer/_meta_trace.py @@ -0,0 +1,133 @@ +import torch +from torch.fx import Graph, Node +from torch.utils._pytree import tree_map + + +def normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +def is_autogradable(x): + return isinstance(x, torch.Tensor) and x.is_floating_point() + + +def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph: + """Trace forward and backward graph with MetaTensor + + Args: + module (torch.nn.Module): The target module for tracing. + + Returns: + graph (torch.fx.Graph): The computation graph. + + Usage: + >>> import torchvision.models as tm + >>> model = tm.alexnet() + >>> graph = meta_trace(model, torch.rand(1000, 3, 224, 224)) + >>> graph.print_tabular() + """ + graph = Graph() + namespace = graph._graph_namespace + + class MetaProxy(torch.Tensor): + """ + A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops. + """ + + _tensor: torch.Tensor + _node: Node + + __slots__ = ['_tensor', '_node'] + + @staticmethod + def __new__(cls, tensor, fake_device=None, placeholder=False, name=None): + r = torch.Tensor._make_wrapper_subclass( + cls, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + device=fake_device if fake_device is not None else tensor.device, + requires_grad=tensor.requires_grad) # deceive the frontend for aten selections + r._tensor = tensor + if placeholder: + if name is None: + name = 'input' + r._node = graph.create_node('placeholder', + 'placeholder', (graph._root,), + name=namespace.create_name(name, tensor)) + # ...the real tensor is held as an element on the tensor. + if not r._tensor.is_meta: + r._tensor = r._tensor.to(torch.device('meta')) + return r + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + + def unwrap(x): + nonlocal fake_device + if isinstance(x, MetaProxy): + fake_device = x.device + x = x._tensor + # assert not isinstance(x, MetaProxy) + elif isinstance(x, torch.Tensor): + fake_device = x.device + x = x.to(torch.device('meta')) + return x + + def get_node(x): + if isinstance(x, torch.Tensor) and not hasattr(x, '_node'): + x = MetaProxy(x, placeholder=True, name='weight') + return x if not hasattr(x, '_node') else x._node + + args_node = tree_map(get_node, args) + kwargs_node = tree_map(get_node, kwargs) + node = graph.create_node('call_function', func, args_node, kwargs_node) + + if 'device' in kwargs: + fake_device = kwargs['device'] + kwargs['device'] = torch.device('meta') + + args = tree_map(unwrap, args) + kwargs = tree_map(unwrap, kwargs) + + # run aten for backend=CPU but actually on backend=Meta + out = func(*args, **kwargs) + + # Now, we want to continue propagating this tensor, so we rewrap Tensors in + # our custom tensor subclass + def wrap(x): + if isinstance(x, torch.Tensor): + nonlocal fake_device + if not x.is_meta: + x = x.to(torch.device('meta')) + return MetaProxy( + x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x + + def set_node(x): + x._node = node + + out = tree_map(wrap, out) + tree_map(set_node, out) + + return out + + def wrap(x): + return MetaProxy(x, fake_device=fake_device, placeholder=True) if isinstance(x, torch.Tensor) else x + + args = tree_map(wrap, args) + kwargs = tree_map(wrap, kwargs) + + out = module(*args, **kwargs) + + for tensor in normalize_tuple(out): + if is_autogradable(tensor) and tensor.requires_grad: + grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance( + tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta')) + torch.autograd.backward(tensor, + MetaProxy(grad, fake_device=tensor.device, placeholder=True), + retain_graph=True) + return graph diff --git a/colossalai/fx/tracer/_symbolic_trace.py b/colossalai/fx/tracer/_symbolic_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..bff2f6a10fa6a54b4f52c6622a526b2bf6c82a41 --- /dev/null +++ b/colossalai/fx/tracer/_symbolic_trace.py @@ -0,0 +1,54 @@ +from typing import Any, Callable, Dict, Optional, Union + +import torch + +from colossalai.fx import ColoGraphModule +from colossalai.fx._compatibility import compatibility + +from .tracer import ColoTracer + + +@compatibility(is_backward_compatible=True) +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + meta_args: Optional[Dict[str, Any]] = None, +) -> ColoGraphModule: + """ + Symbolic tracing API + + Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule`` + constructed by recording operations seen while tracing through ``root``. + + With ``meta_args``, we can trace the model that are untraceable subject to control flow. If specified using + ``meta_args`` only, the tracing can be done ahead of time. + + Note that ``meta_args`` are kwargs, which contains the key of the argument's names and the value of the + argument's values. + + Uses: + >>> model = ... + + # if this works + >>> gm = symbolic_trace(model, concrete_args=concrete_args) + + # else try this + >>> gm = symbolic_trace(model, concrete_args=concrete_args, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')}) + + Args: + root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted + into a Graph representation. + concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be used for tracing. + meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for ``ColoTracer``. + Defaults to None. + + Returns: + ColoGraphModule: A ``ColoGraphModule`` created from the recorded operations from ``root``. + + Warnings: + This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team. + + """ + graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args) + name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + return ColoGraphModule(root, graph, name) diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0ec49a90a133b7a8db374844a1d343fa02ef0fd0 --- /dev/null +++ b/colossalai/fx/tracer/_tracer_utils.py @@ -0,0 +1,50 @@ +from typing import List, Union, Any +from ..proxy import ColoProxy, ColoAttribute +import torch +from .meta_patch import meta_patched_function, meta_patched_module + +__all__ = ['is_element_in_list', 'extract_meta'] + + +def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]): + if isinstance(elements, (tuple, list, set)): + for ele in elements: + if ele not in list_: + return False, ele + else: + if elements not in list_: + return False, elements + + return True, None + + +def extract_meta(*args, **kwargs): + + def _convert(val): + if isinstance(val, ColoProxy): + return val.meta_data + elif isinstance(val, (list, tuple)): + return type(val)([_convert(ele) for ele in val]) + + return val + + new_args = [_convert(val) for val in args] + new_kwargs = {k: _convert(v) for k, v in kwargs.items()} + return new_args, new_kwargs + + +def compute_meta_data_for_functions_proxy(target, args, kwargs): + args_metas, kwargs_metas = extract_meta(*args, **kwargs) + + # fetch patched function + if meta_patched_function.has(target): + meta_target = meta_patched_function.get(target) + elif meta_patched_function.has(target.__name__): + meta_target = meta_patched_function.get(target.__name__) + else: + meta_target = target + meta_out = meta_target(*args_metas, **kwargs_metas) + if isinstance(meta_out, torch.Tensor): + meta_out = meta_out.to(device="meta") + + return meta_out diff --git a/colossalai/fx/tracer/bias_addition_patch/__init__.py b/colossalai/fx/tracer/bias_addition_patch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e724d6a22fa84ecd954a59ebc6eb9b8daa00a035 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/__init__.py @@ -0,0 +1,2 @@ +from .patched_bias_addition_function import * +from .patched_bias_addition_module import * diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..071bde4a5293e391d618973f803360deb4cd1b4c --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py @@ -0,0 +1,4 @@ +from .addbmm import Addbmm +from .addmm import Addmm +from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict +from .linear import Linear diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py new file mode 100644 index 0000000000000000000000000000000000000000..859a19bf6241bbbf4061e0f7564975682527b8c2 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py @@ -0,0 +1,75 @@ +import operator + +import torch +import torch.nn.functional as F + +from ...registry import bias_addition_function, bias_addition_method +from .bias_addition_function import LinearBasedBiasFunc + + +@bias_addition_method.register(torch.Tensor.addbmm) +@bias_addition_function.register(torch.addbmm) +class Addbmm(LinearBasedBiasFunc): + + def extract_kwargs_from_origin_func(self): + kwargs = {} + if 'beta' in self.kwargs: + kwargs['beta'] = self.kwargs['beta'] + if 'alpha' in self.kwargs: + kwargs['alpha'] = self.kwargs['alpha'] + return kwargs + + def create_non_bias_func_proxy(self, input_proxy, other_proxy): + """ + This method is used to create the non_bias_func proxy, the node created by this proxy will + compute the main computation, such as convolution, with bias option banned. + """ + assert self.substitute_func == torch.bmm + node_kind = 'call_function' + node_target = self.substitute_func + + node_args = (input_proxy, other_proxy) + # torch.bmm does not have any kwargs + node_kwargs = {} + non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return non_bias_func_proxy + + def insert_sum_node(self, input_proxy, sum_dims=0): + ''' + This method is used to sum the input_proxy through the sum_dims. + ''' + node_kind = 'call_function' + node_target = torch.sum + node_args = (input_proxy, sum_dims) + node_kwargs = {} + sum_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return sum_proxy + + def generate(self): + # The formula for addbmm is output = beta * input + alpha * (torch.bmm(b1, b2)) + + # doing the non-bias computation(temp_0 = torch.bmm(b1, b2)) + non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], self.args[2]) + + # doing sum on the batch dimension(temp_1 = torch.sum(temp_0, 0)) + sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy) + kwargs = self.extract_kwargs_from_origin_func() + + if 'beta' in kwargs: + beta = kwargs['beta'] + # doing the multiplication with beta if it exists(temp_2 = beta * input) + beta_proxy = self.create_mul_node(self.args[0], beta) + else: + beta_proxy = self.args[0] + + if 'alpha' in kwargs: + alpha = kwargs['alpha'] + # doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1) + alpha_proxy = self.create_mul_node(alpha, sum_proxy) + else: + alpha_proxy = sum_proxy + + # doing the addition(temp_4 = temp_2 + temp_3) + bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy) + + return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py new file mode 100644 index 0000000000000000000000000000000000000000..fe7d8d07aac941028d7c682043b5af2bcdf2537a --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py @@ -0,0 +1,60 @@ +import operator + +import torch +import torch.nn.functional as F + +from ...registry import bias_addition_function, bias_addition_method +from .bias_addition_function import LinearBasedBiasFunc + + +@bias_addition_method.register(torch.Tensor.addmm) +@bias_addition_function.register(torch.addmm) +class Addmm(LinearBasedBiasFunc): + + def extract_kwargs_from_origin_func(self): + kwargs = {} + if 'beta' in self.kwargs: + kwargs['beta'] = self.kwargs['beta'] + if 'alpha' in self.kwargs: + kwargs['alpha'] = self.kwargs['alpha'] + return kwargs + + def transpose_other_operand_for_linear(self, other_proxy): + ''' + This method is used to transpose the other operand for linear function. + For example: + input = torch.rand(3, 4) + m1 = torch.rand(3, 5) + m2 = torch.rand(5, 4) + original_output = torch.addmm(input, m1, m2) + # To keep the computation graph consistent with the origin computation graph, we need to transpose the m2 + # before we call the linear function. + new_output = torch.linear(m1, m2.transpose(0, 1)) + input + ''' + node_kind = 'call_function' + node_target = torch.transpose + node_args = (other_proxy, 0, 1) + node_kwargs = {} + transpose_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return transpose_proxy + + def generate(self): + transpose_proxy = self.transpose_other_operand_for_linear(self.args[2]) + non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy) + kwargs = self.extract_kwargs_from_origin_func() + + if 'beta' in kwargs: + beta = kwargs['beta'] + beta_proxy = self.create_mul_node(self.args[0], beta) + else: + beta_proxy = self.args[0] + + if 'alpha' in kwargs: + alpha = kwargs['alpha'] + alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy) + else: + alpha_proxy = non_bias_linear_func_proxy + + bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy) + + return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py new file mode 100644 index 0000000000000000000000000000000000000000..8a3786332c08d9a3320ce4c7bee8221dd3d10abd --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py @@ -0,0 +1,115 @@ +import operator +from abc import ABC, abstractmethod + +import torch +import torch.nn.functional as F + + +class BiasAdditionFunc(ABC): + """ + This class is used to construct the restructure computation graph for + call_func node with bias addition inside. + """ + + def __init__(self, tracer, target, args, kwargs, substitute_func): + self.tracer = tracer + self.target = target + self.args = args + self.kwargs = kwargs + self.substitute_func = substitute_func + + @abstractmethod + def extract_kwargs_from_origin_func(self): + """ + This method is used to extract the kwargs for further graph transform. + + For example: + The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2) + The kwargs for addmm function is {beta=1, alpha=1, output=None}, then we need + to insert two more operator.mul nodes for the computation graph to compute the + final result. + """ + pass + + @abstractmethod + def generate(self): + """ + This method is used to construct the whole restructure computation graph for call_func node with bias + addition inside. + + A whole restructure computation graph will contain a weight node, a bias node, a non-bias addition computation node, + a bias reshape node if needed and a bias addition node. + + Use torch.addmm as an example: + The origin node is: + %addmm: call_func[target=torch.addmm](args = (%input_1, m1, m2), kwargs = {beta=1, alpha=1}) + Restructured graph is: + %transpose : [#users=1] = call_function[target=torch.transpose](args = (%m2, 0, 1), kwargs = {}) + %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%m1, %transpose), kwargs = {}) + %mul : [#users=1] = call_function[target=operator.mul](args = (%input_1, 3), kwargs = {}) + %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {}) + %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {}) + """ + pass + + def create_mul_node(self, input_proxy, coefficent): + """ + This method is used to create a coefficent node for the numerical correctness. + The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2) + Therefore, we need to use this method insert two more operator.mul nodes for + the computation graph to compute the final result. + """ + node_kind = 'call_function' + node_target = operator.mul + node_args = ( + input_proxy, + coefficent, + ) + node_kwargs = {} + mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return mul_proxy + + +class LinearBasedBiasFunc(BiasAdditionFunc): + """ + This class is used to construct the restructure computation graph for + call_func node based on F.linear. + """ + + def create_non_bias_func_proxy(self, input_proxy, other_proxy): + """ + This method is used to create the non_bias_func proxy, the node created by this proxy will + compute the main computation, such as convolution, with bias option banned. + """ + assert self.substitute_func == torch.nn.functional.linear + node_kind = 'call_function' + node_target = self.substitute_func + + node_args = (input_proxy, other_proxy) + # non-bias linear does not have any kwargs + node_kwargs = {} + non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return non_bias_func_proxy + + def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy): + """ + This method is used to create the bias_addition_proxy, the node created by this proxy will + compute the sum of non_bias_func result and bias with some reshape operation if needed. + """ + bias_add_node_kind = 'call_function' + bias_add_node_target = operator.add + bias_add_args = (non_bias_func_proxy, bias_proxy) + bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {}) + return bias_add_proxy + + +func_to_func_dict = { + torch.addmm: F.linear, + torch.addbmm: torch.bmm, + F.linear: F.linear, +} + +method_to_func_dict = { + torch.Tensor.addmm: F.linear, + torch.Tensor.addbmm: torch.bmm, +} diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..e11ec0a364f1e5ee1445c97ae0c9b054d02bcfa2 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py @@ -0,0 +1,25 @@ +import operator + +import torch +import torch.nn.functional as F + +from ...registry import bias_addition_function +from .bias_addition_function import LinearBasedBiasFunc + + +@bias_addition_function.register(F.linear) +class Linear(LinearBasedBiasFunc): + + def extract_kwargs_from_origin_func(self): + assert 'bias' in self.kwargs + kwargs = {} + if 'bias' in self.kwargs: + kwargs['bias'] = self.kwargs['bias'] + return kwargs + + def generate(self): + non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1]) + kwargs = self.extract_kwargs_from_origin_func() + bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias']) + + return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/__init__.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3823bb3e2a20e3963cc451459ec963a36eb1139 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/__init__.py @@ -0,0 +1,3 @@ +from .bias_addition_module import * +from .conv import * +from .linear import * diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py new file mode 100644 index 0000000000000000000000000000000000000000..85f1553e304c9c45b2b8f1373e76e7023d452d47 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py @@ -0,0 +1,111 @@ +import operator +from abc import ABC, abstractmethod + +import torch +import torch.nn.functional as F + + +class BiasAdditionModule(ABC): + """ + This class is used to construct the restructure computation graph for + call_module node with bias addition inside. + """ + + def __init__(self, tracer, target, args, kwargs, substitute_func): + self.tracer = tracer + self.target = target + self.args = args + self.kwargs = kwargs + self.substitute_func = substitute_func + self.weight_proxy = self._create_weight_proxy() + self.bias_proxy = self._create_bias_proxy() + + def _create_weight_proxy(self): + """ + Create weight proxy, the node created by this proxy contains module weight. + + Note: this function will be invoked during module initializing, + you should never call this function. + """ + weight_node_kind = 'get_attr' + weight_node_target = self.target + '.weight' + weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {}) + return weight_proxy + + def _create_bias_proxy(self): + """ + Create bias proxy, the node created by this proxy contains module bias. + + Note: this function will be invoked during module initializing, + you should never call this function. + """ + bias_node_kind = 'get_attr' + bias_node_target = self.target + '.bias' + bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {}) + return bias_proxy + + @abstractmethod + def extract_kwargs_from_mod(self): + """ + This method is used to extract the kwargs for non-bias computation. + + For example: + The kwargs for conv2d module is {} because the attributes like 'padding' or 'groups' are + considered during module initilizing. However, we need to consider those attributes as kwargs + in F.conv2d. + """ + pass + + def create_non_bias_func_proxy(self, input_proxy=None): + """ + This method is used to create the non_bias_func proxy, the node created by this proxy will + compute the main computation, such as convolution, with bias option banned. + """ + node_kind = 'call_function' + node_target = self.substitute_func + if input_proxy is None: + input_proxy = self.args[0] + node_args = (input_proxy, self.weight_proxy) + node_kwargs = self.extract_kwargs_from_mod() + non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return non_bias_func_proxy + + def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy): + """ + This method is used to create the bias_addition_proxy, the node created by this proxy will + compute the sum of non_bias_func result and bias with some reshape operation if needed. + """ + bias_add_node_kind = 'call_function' + bias_add_node_target = operator.add + bias_add_args = (non_bias_func_proxy, bias_proxy) + bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {}) + return bias_add_proxy + + @abstractmethod + def generate(self): + """ + This method is used to construct the whole restructure computation graph for call_module node with bias + addition inside. + + A whole restructure computation graph will contain a weight node, a bias node, a non-bias addition computation node, + a bias reshape node if needed and a bias addition node. + + Use Conv2d module as an example: + The origin node is: + %conv: call_module[target=conv](args = (%x,), kwargs = {}) + Restructured graph is: + %conv_weight : [#users=1] = get_attr[target=conv.weight] + %conv_bias : [#users=1] = get_attr[target=conv.bias] + %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {}) + %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + """ + pass + + +module_to_func_dict = { + torch.nn.Linear: F.linear, + torch.nn.Conv1d: F.conv1d, + torch.nn.Conv2d: F.conv2d, + torch.nn.Conv3d: F.conv3d, +} diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..4b6c82a74f57d213ba3d8b68863053ddd1aabf9d --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py @@ -0,0 +1,56 @@ +import torch +import torch.nn.functional as F +from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple + +from ...registry import bias_addition_module +from .bias_addition_module import BiasAdditionModule + + +@bias_addition_module.register(torch.nn.Conv1d) +@bias_addition_module.register(torch.nn.Conv2d) +@bias_addition_module.register(torch.nn.Conv3d) +class BiasAdditionConv(BiasAdditionModule): + + def extract_kwargs_from_mod(self): + root = self.tracer.root + conv_module = root.get_submodule(self.target) + kwarg_attributes = ['groups', 'dilation', 'stride'] + non_bias_kwargs = {} + for attr_name in kwarg_attributes: + if hasattr(conv_module, attr_name): + non_bias_kwargs[attr_name] = getattr(conv_module, attr_name) + if conv_module.padding_mode != "zeros": + #TODO: non zeros mode requires some extra processing for input + conv_type = type(conv_module) + if conv_type == "torch.nn.Conv1d": + padding_element = _single(0) + elif conv_type == "torch.nn.Conv2d": + padding_element = _pair(0) + elif conv_type == "torch.nn.Conv3d": + padding_element = _triple(0) + non_bias_kwargs['padding'] = padding_element + else: + non_bias_kwargs['padding'] = getattr(conv_module, 'padding') + + return non_bias_kwargs + + def create_bias_reshape_proxy(self, dimensions): + """ + This method is used to reshape the bias node in order to make bias and + output of non-bias convolution broadcastable. + """ + bias_shape = [1] * (dimensions - 1) + bias_shape[0] = -1 + bias_reshape_node_kind = 'call_method' + bias_reshape_node_target = 'view' + bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape)) + bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target, + bias_reshape_node_args, {}) + return bias_reshape_proxy + + def generate(self): + non_bias_conv_func_proxy = self.create_non_bias_func_proxy() + output_dims = non_bias_conv_func_proxy.meta_data.dim() + bias_reshape_proxy = self.create_bias_reshape_proxy(output_dims) + bias_addition_proxy = self.create_bias_addition_proxy(non_bias_conv_func_proxy, bias_reshape_proxy) + return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f7b6ddab401a637aa2b43b4dd8d2ce9193266e --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py @@ -0,0 +1,17 @@ +import torch +import torch.nn.functional as F + +from ...registry import bias_addition_module +from .bias_addition_module import BiasAdditionModule + + +@bias_addition_module.register(torch.nn.Linear) +class BiasAdditionLinear(BiasAdditionModule): + + def extract_kwargs_from_mod(self): + return {} + + def generate(self): + non_bias_linear_func_proxy = self.create_non_bias_func_proxy() + bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, self.bias_proxy) + return bias_addition_proxy diff --git a/colossalai/fx/tracer/experimental.py b/colossalai/fx/tracer/experimental.py new file mode 100644 index 0000000000000000000000000000000000000000..66e7149120f31204a08d63ba268475c6d2014de4 --- /dev/null +++ b/colossalai/fx/tracer/experimental.py @@ -0,0 +1,394 @@ +import enum +import functools +import inspect +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import torch +from torch.fx import Graph, Node, Proxy, Tracer +from torch.utils._pytree import tree_map + +from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + +Target = Union[Callable[..., Any], str] +Argument = Optional[Union[Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types + List[Any], # actually Argument + Dict[str, Any], # actually Argument + slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing + 'Node',]] +_CScriptMethod = ['add', 'mul', 'sub', 'div'] +_TorchNewMethod = [ + "arange", "zeros", "zeros_like", "ones", "ones_like", "full", "full_like", "empty", "empty_like", "eye", "tensor", + "finfo" +] +_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"] + + +def _truncate_suffix(s: str): + import re + return re.sub(r'_\d+$', '', s) + + +def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]): + if isinstance(elements, (tuple, list, set)): + for ele in elements: + if ele not in list_: + return False, ele + else: + if elements not in list_: + return False, elements + + return True, None + + +def default_device(): + return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + + +@compatibility(is_backward_compatible=False) +class ColoProxy(Proxy): + + def __init__(self, *args, data=None, **kwargs): + super().__init__(*args, **kwargs) + self._data = data + + @property + def data(self): + return self._data + + @data.setter + def data(self, args): + wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x + self._data = tree_map(wrap_fn, args) + + @classmethod + def __torch_function__(cls, orig_method, types, args=(), kwargs=None): + proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs)) + unwrap_fn = lambda p: p.data if isinstance(p, ColoProxy) else p + kwargs = {} if kwargs is None else kwargs + if proxy.data is None: + proxy.data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + return proxy + + @classmethod + def from_torch_proxy(cls, proxy: Proxy): + return cls(proxy.node, proxy.tracer) + + def __repr__(self): + return f"ColoProxy({self.node.name}, data={self.data})" + + def __len__(self): + return len(self.data) + + def __int__(self): + return int(self.data) + + def __index__(self): + try: + return int(self.data) + except: + return torch.zeros(self.data.shape, dtype=torch.bool).numpy().__index__() + + def __float__(self): + return float(self.data) + + def __bool__(self): + return self.data + + def __getattr__(self, k): + return ColoAttribute(self, k, getattr(self._data, k, None)) + + def __contains__(self, key): + if self.node.op == "placeholder": + # this is used to handle like + # if x in kwargs + # we don't handle this case for now + return False + return super().__contains__(key) + + def __isinstancecheck__(self, type): + return isinstance(self.data, type) + + @property + def shape(self): + return self.data.shape + + @property + def ndim(self): + return self.data.ndim + + @property + def device(self): + proxy = self.tracer.create_proxy('call_function', getattr, (self, 'device'), {}) + proxy.data = self.data.device + return proxy + + @property + def dtype(self): + proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {}) + proxy.data = self.data.dtype + return proxy + + def to(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs}) + + def cpu(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs}) + + def cuda(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs}) + + +@compatibility(is_backward_compatible=False) +class ColoAttribute(ColoProxy): + + def __init__(self, root, attr: str, data=None): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._data = data + self._node: Optional[Node] = None + + @property + def node(self): + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + return self._node + + def __call__(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + + def __repr__(self): + return f"ColoAttribute({self.node.name}, attr={self.attr})" + + +@compatibility(is_backward_compatible=False) +class ColoTracer(Tracer): + + def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs): + super().__init__(*args, **kwargs) + self._disable_module_getattr = False + self.proxy_buffer_attributes = True + + def proxy(self, node: Node) -> 'ColoProxy': + return ColoProxy(node, self) + + def create_proxy(self, + kind: str, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Callable[[Node], 'Proxy'] = None): + proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) + unwrap_fn = lambda p: p.data if isinstance(p, ColoProxy) else p + if kind == 'placeholder': + proxy.data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get( + _truncate_suffix(target), None) + elif kind == 'get_attr': + self._disable_module_getattr = True + try: + attr_itr = self.root + atoms = target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + proxy.data = attr_itr + finally: + self._disable_module_getattr = False + elif kind == 'call_function': + proxy.data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + elif kind == 'call_method': + self._disable_module_getattr = True + try: + if target == '__call__': + proxy.data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) + else: + if target not in _TensorPropertyMethod: + proxy._data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), + **tree_map(unwrap_fn, kwargs)) + finally: + self._disable_module_getattr = False + elif kind == 'call_module': + mod = self.root.get_submodule(target) + unwrap_fn = lambda p: p.data if isinstance(p, ColoProxy) else p + self._disable_module_getattr = True + try: + proxy.data = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + finally: + self._disable_module_getattr = True + return proxy + + def trace(self, + root: torch.nn.Module, + concrete_args: Optional[Dict[str, torch.Tensor]] = None, + meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph: + + if meta_args is None: + meta_args = {} + + if concrete_args is None: + concrete_args = {} + + # check concrete and meta args have valid names + sig = inspect.signature(root.forward) + sig_names = set(sig.parameters.keys()) + meta_arg_names = set(meta_args.keys()) + + # update concrete args with default values + non_meta_arg_names = sig_names - meta_arg_names + for k, v in sig.parameters.items(): + if k in non_meta_arg_names and \ + k not in concrete_args and \ + v.default is not inspect.Parameter.empty: + concrete_args[k] = v.default + + # get non concrete arg names + concrete_arg_names = set(concrete_args.keys()) + non_concrete_arg_names = sig_names - concrete_arg_names + + def _check_arg_name_valid(names): + success, element = is_element_in_list(names, sig_names) + if not success: + raise KeyError( + f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function") + + _check_arg_name_valid(meta_arg_names) + _check_arg_name_valid(concrete_arg_names) + + self.concrete_args = concrete_args + self.meta_args = meta_args + + with _TorchTensorOverride(self): + self.graph = super().trace(root, concrete_args=concrete_args) + self.graph.lint() + return self.graph + + def _post_check(self, non_concrete_arg_names: Set[str]): + # This is necessary because concrete args are added as input to the traced module since + # https://github.com/pytorch/pytorch/pull/55888. + for node in self.graph.nodes: + if node.op == "placeholder": + # Removing default values for inputs as the forward pass will fail with them. + if node.target in non_concrete_arg_names: + node.args = () + # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor]. + # It cannot infer on the attributes and methods the input should have, and fails. + node.type = torch.Tensor + # It is a concrete arg so it is not used and should be removed. + else: + if hasattr(torch.fx._symbolic_trace, "_assert_is_none"): + # Newer versions of torch.fx emit an assert statement + # for concrete arguments; delete those before we delete + # the concrete arg. + to_delete = [] + for user in node.users: + if user.target == torch.fx._symbolic_trace._assert_is_none: + to_delete.append(user) + for user in to_delete: + self.graph.erase_node(user) + + self.graph.erase_node(node) + + # TODO: solves GraphModule creation. + # Without this, return type annotation "Tuple" is causing code execution failure. + if node.op == "output": + node.type = None + self.graph.lint() + + def _module_getattr(self, attr, attr_val, parameter_proxy_cache): + if getattr(self, "_disable_module_getattr", False): + return attr_val + + def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache): + for n, p in collection_to_search: + if attr_val is p: + if n not in parameter_proxy_cache: + kwargs = {} + if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters: + kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else + lambda node: ColoProxy(self, node, n, attr_val)) + val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type] + parameter_proxy_cache[n] = val_proxy + return parameter_proxy_cache[n] + return None + + if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): + maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), parameter_proxy_cache) + if maybe_buffer_proxy is not None: + return maybe_buffer_proxy + + if isinstance(attr_val, torch.nn.Parameter): + maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), + parameter_proxy_cache) + if maybe_parameter_proxy is not None: + return maybe_parameter_proxy + + return attr_val + + +@compatibility(is_backward_compatible=True) +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + meta_args: Optional[Dict[str, Any]] = None, +) -> ColoGraphModule: + if is_compatible_with_meta(): + if meta_args is not None: + root.to(default_device()) + wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x + graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)) + root.cpu() + else: + graph = Tracer().trace(root, concrete_args=concrete_args) + else: + from .tracer import ColoTracer as OrigColoTracer + graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args) + name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + return ColoGraphModule(root, graph, name) + + +@compatibility(is_backward_compatible=False) +class _TorchTensorOverride(object): + + def __init__(self, tracer: Tracer): + self.overrides = {} + self.tracer = tracer + + def __enter__(self): + + def wrap_tensor_method(target): + + @functools.wraps(target) + def wrapper(*args, **kwargs): + is_proxy = any(isinstance(p, ColoProxy) for p in args) | any( + isinstance(p, ColoProxy) for p in kwargs.values()) + if is_proxy: + # if the arg is a proxy, then need to record this function called on this proxy + # e.g. torch.ones(size) where size is an input proxy + self.tracer._disable_module_getattr = True + try: + proxy = self.tracer.create_proxy('call_function', target, args, kwargs) + finally: + self.tracer._disable_module_getattr = False + return proxy + else: + return target(*args, **kwargs) + + return wrapper, target + + self.overrides = { + target: wrap_tensor_method(getattr(torch, target)) + for target in _TorchNewMethod + if callable(getattr(torch, target)) + } + for name, (wrapper, orig) in self.overrides.items(): + setattr(torch, name, wrapper) + + def __exit__(self, exc_type, exc_val, exc_tb): + for name, (wrapper, orig) in self.overrides.items(): + setattr(torch, name, orig) diff --git a/colossalai/fx/tracer/meta_patch/__init__.py b/colossalai/fx/tracer/meta_patch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..192aef7a4ba0388a817c8146b55c41311faa577a --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/__init__.py @@ -0,0 +1,2 @@ +from .patched_function import * +from .patched_module import * diff --git a/colossalai/fx/tracer/meta_patch/patched_function/__init__.py b/colossalai/fx/tracer/meta_patch/patched_function/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00fdf6f5c328e4e92a3911094325ada2182c5f9 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/__init__.py @@ -0,0 +1,6 @@ +from .activation_function import * +from .arithmetic import * +from .convolution import * +from .embedding import * +from .normalization import * +from .torch_ops import * diff --git a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py new file mode 100644 index 0000000000000000000000000000000000000000..12c42514895e61777f0dbb1c348206af7d0ac5ec --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py @@ -0,0 +1,8 @@ +import torch + +from ...registry import meta_patched_function + + +@meta_patched_function.register(torch.nn.functional.relu) +def torch_nn_func_relu(input, inplace=False): + return torch.empty(input.shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py new file mode 100644 index 0000000000000000000000000000000000000000..042b92c5847a4ed0d78d4acf068a0a5030fa7089 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py @@ -0,0 +1,95 @@ +import torch + +from ...registry import meta_patched_function + + +@meta_patched_function.register(torch.matmul) +@meta_patched_function.register('matmul') # for built-in op @ +def torch_matmul(input, other, *, out=None): + # copied from huggingface.utils.fx + d1 = input.dim() + d2 = other.dim() + shape = None + if d1 == 1 and d2 == 1: + shape = None + elif d1 == 2 and d2 == 2: + shape = (input.size(0), other.size(1)) + elif d1 == 1 and d2 == 2: + shape = (other.size(1),) + elif d1 == 2 and d2 == 1: + shape = (input.size(0),) + else: + max_length = max(input.dim(), other.dim()) + shape1 = list(input.shape) + shape2 = list(other.shape) + if d1 == 1: + shape1 = [1] + shape1 + if d2 == 1: + shape2.append(1) + shape1 = [-1] * (max_length - d1) + list(input.shape) + shape2 = [-1] * (max_length - d2) + list(other.shape) + shape = [] + for i in range(max_length): + shape.append(max(shape1[i], shape2[i])) + shape[-2] = shape1[-2] + shape[-1] = shape2[-1] + if d1 == 1: + shape.pop(-2) + if d2 == 1: + shape.pop(-1) + if shape is None: + return torch.tensor(0.0, device="meta") + return torch.empty(*shape, device="meta") + + +@meta_patched_function.register(torch.abs) +def torch_abs(input, *, out=None): + assert out is None, 'out is not supported yet' + return torch.empty(input.shape, device='meta') + + +@meta_patched_function.register(torch.bmm) +def torch_bmm(input, mat2, *, out=None): + if out is not None: + raise ValueError("Don't support in-place abs for MetaTensor analysis") + batch_size, n, m = input.shape + _, _, p = mat2.shape + return torch.empty(batch_size, n, p, device="meta") + + +@meta_patched_function.register(torch.nn.functional.linear) +def torch_linear(input, mat2, bias=None, *, out=None): + if out is not None: + raise ValueError("Don't support in-place abs for MetaTensor analysis") + output_shape = list(input.shape) + output_feature = list(mat2.shape)[0] + output_shape[-1] = output_feature + return torch.empty(*output_shape, device="meta") + + +@meta_patched_function.register(torch.addbmm) +@meta_patched_function.register(torch.Tensor.addbmm) +def torch_addbmm(input, mat1, mat2, *, beta=1, alpha=1, out=None): + if out is not None: + raise ValueError("Don't support in-place abs for MetaTensor analysis") + _, n, _ = mat1.shape + _, _, p = mat2.shape + return torch.empty(n, p, device="meta") + + +@meta_patched_function.register(torch.addmm) +@meta_patched_function.register(torch.Tensor.addmm) +def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None): + if out is not None: + raise ValueError("Don't support in-place abs for MetaTensor analysis") + n, _ = mat1.shape + _, p = mat2.shape + return torch.empty(n, p, device="meta") + + +@meta_patched_function.register(torch.var_mean) +def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None): + assert out is None, 'saving to out is not supported yet' + var = torch.empty(1).squeeze(0).to('meta') + mean = torch.empty(1).squeeze(0).to('meta') + return var, mean diff --git a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..8500e5c82508195ca3ec8ebc99a33ad2a1b946ad --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py @@ -0,0 +1,180 @@ +import collections +import math +from itertools import repeat + +import torch + +from ...registry import meta_patched_function + + +def _ntuple(n, name="parse"): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return tuple(x) + return tuple(repeat(x, n)) + + parse.__name__ = name + return parse + + +_single = _ntuple(1, "_single") +_pair = _ntuple(2, "_pair") +_triple = _ntuple(3, "_triple") + + +def _extract_kwargs(kwargs): + if 'stride' in kwargs: + stride = kwargs['stride'] + else: + stride = 1 + # TODO: process str type padding + if 'padding' in kwargs: + padding = kwargs['padding'] + else: + padding = 0 + if 'dilation' in kwargs: + dilation = kwargs['dilation'] + else: + dilation = 1 + if 'output_padding' in kwargs: + output_padding = kwargs['output_padding'] + else: + output_padding = 0 + + return stride, padding, dilation, output_padding + + +@meta_patched_function.register(torch.nn.functional.conv1d) +def torch_nn_functional_conv1d(input, weight, **kwargs): + stride, padding, dilation, _ = _extract_kwargs(kwargs) + + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + + kernel_size = weight.shape[2:] + l_in = input.shape[-1] + c_out = weight.shape[0] + l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) + result_shape = input.shape[:-2] + ( + c_out, + l_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_function.register(torch.nn.functional.conv2d) +def torch_nn_functional_conv2d(input, weight, **kwargs): + stride, padding, dilation, _ = _extract_kwargs(kwargs) + + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + + kernel_size = weight.shape[2:] + h_in, w_in = input.shape[-2:] + c_out = weight.shape[0] + h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) + w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1) + result_shape = input.shape[:-3] + ( + c_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_function.register(torch.nn.functional.conv3d) +def torch_nn_functional_conv3d(input, weight, **kwargs): + stride, padding, dilation, _ = _extract_kwargs(kwargs) + + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + + kernel_size = weight.shape[2:] + d_in, h_in, w_in = input.shape[-3:] + c_out = weight.shape[0] + d_out = math.floor((d_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) + h_out = math.floor((h_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1) + w_out = math.floor((w_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1) + result_shape = input.shape[:-4] + ( + c_out, + d_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_function.register(torch.nn.functional.conv_transpose1d) +def torch_nn_functional_convtranspose1d(input, weight, **kwargs): + stride, padding, dilation, output_padding = _extract_kwargs(kwargs) + + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + output_padding = _single(output_padding) + + kernel_size = weight.shape[2:] + l_in = input.shape[-1] + c_out = weight.shape[1] + l_out = math.floor((l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + + output_padding[0] + 1) + result_shape = input.shape[:-2] + ( + c_out, + l_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_function.register(torch.nn.functional.conv_transpose2d) +def torch_nn_functional_convtranspose2d(input, weight, **kwargs): + stride, padding, dilation, output_padding = _extract_kwargs(kwargs) + + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + output_padding = _pair(output_padding) + + kernel_size = weight.shape[2:] + h_in, w_in = input.shape[-2:] + c_out = weight.shape[1] + h_out = math.floor((h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + + output_padding[0] + 1) + w_out = math.floor((w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + + output_padding[1] + 1) + result_shape = input.shape[:-3] + ( + c_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_function.register(torch.nn.functional.conv_transpose3d) +def torch_nn_functional_convtranspose3d(input, weight, **kwargs): + stride, padding, dilation, output_padding = _extract_kwargs(kwargs) + + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + output_padding = _triple(output_padding) + + kernel_size = weight.shape[2:] + d_in, h_in, w_in = input.shape[-3:] + c_out = weight.shape[1] + d_out = math.floor((d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + + output_padding[0] + 1) + h_out = math.floor((h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + + output_padding[1] + 1) + w_out = math.floor((w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) + + output_padding[2] + 1) + result_shape = input.shape[:-4] + ( + c_out, + d_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..6d8d864ea29acd648f7f7097821c248655b0191e --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py @@ -0,0 +1,14 @@ +import torch + +from ...registry import meta_patched_function + + +@meta_patched_function.register(torch.nn.functional.embedding) +def torch_nn_functional_embedding(input, + weight, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False): + return torch.empty(*input.shape, weight.shape[-1], device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e7eda6159c88ca3a82320c841360df87540c82 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py @@ -0,0 +1,20 @@ +import torch + +from ...registry import meta_patched_function + + +@meta_patched_function.register(torch.nn.functional.layer_norm) +def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05): + return torch.empty(input.shape, device='meta') + + +@meta_patched_function.register(torch.nn.functional.batch_norm) +def torch_nn_func_batchnorm(input, + running_mean, + running_var, + weight=None, + bias=None, + training=False, + momentum=0.1, + eps=1e-05): + return torch.empty(input.shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4c171cb1099119de54b70b03097e1781880f2624 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py @@ -0,0 +1,60 @@ +import operator + +import torch + +from colossalai.fx.proxy import ColoProxy + +from ...registry import meta_patched_function + + +@meta_patched_function.register(operator.getitem) +def operator_getitem(a, b): + # copied from huggingface.utils.fx + def to_concrete(t): + if isinstance(t, torch.Tensor): + concrete = torch.ones_like(t, device="cpu") + if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]: + concrete = concrete.to(torch.int64) + return concrete + return t + + def _slice_convert(slice_obj): + attrs = {'start': slice_obj.start, 'stop': slice_obj.stop, 'step': slice_obj.step} + new_attrs = _slice_attr_convert(attrs) + attr_dict_to_tuple = (new_attrs['start'], new_attrs['stop'], new_attrs['step']) + return slice(*attr_dict_to_tuple) + + def _slice_attr_convert(attrs): + new_attrs = {} + for key, value in attrs.items(): + if isinstance(value, ColoProxy): + new_attrs[key] = value.meta_data + else: + new_attrs[key] = value + return new_attrs + + if isinstance(b, tuple): + b = list(b) + for index, element in enumerate(b): + if isinstance(element, slice): + b[index] = _slice_convert(element) + b = tuple(b) + elif isinstance(b, slice): + b = _slice_convert(b) + + if isinstance(a, torch.Tensor): + # TODO: infer shape without performing the computation. + if isinstance(b, tuple): + b = tuple(map(to_concrete, b)) + else: + b = to_concrete(b) + return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta") + + if isinstance(a, ColoProxy): + # TODO: infer shape without performing the computation. + if isinstance(b, tuple): + b = tuple(map(to_concrete, b)) + else: + b = to_concrete(b) + return operator.getitem(torch.empty_like(a.meta_data, device="cpu"), b).to("meta") + return operator.getitem(a, b) diff --git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..b14ff10ce1377055ed7f3ab3025ee7c05c6a1657 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py @@ -0,0 +1,174 @@ +import torch + +from ...registry import meta_patched_function + + +@meta_patched_function.register(torch.arange) +def torch_arange(*args, **kwargs): + n = len(args) + step = 1 + if n == 1: + start = 0 + end = args[0] + elif n == 2: + start, end = args + else: + start, end, step = args + if isinstance(start, float): + start = int(start) + if isinstance(end, float): + start = int(end) + if isinstance(step, float): + step = int(step) + step = kwargs.get("step", step) + dtype = kwargs.get("dtype") + return torch.empty((end - start) // step, dtype=dtype, device="meta") + + +@meta_patched_function.register(torch.finfo) +def torch_finfo(*args): + return torch.finfo(*args) + + +@meta_patched_function.register(torch.where) +def torch_where(condition, x, y): + # torch.where returns the broadcasted tensor of condition, x, and y, + # so hack it by using addition + return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta") + + +@meta_patched_function.register(torch.Tensor.repeat) +def torch_tensor_repeat(self, *sizes): + shape = list(self.shape) + for i, x in enumerate(sizes): + shape[i] *= x + return torch.empty(shape, device="meta") + + +@meta_patched_function.register(torch.index_select) +def torch_index_select(input, dim, index, *, out=None): + shape = list(input.shape) + shape[dim] = len(index) + return torch.empty(*shape, device="meta") + + +@meta_patched_function.register(torch.Tensor.index_select) +def torch_tensor_index_select(self, dim, index): + return torch_index_select(self, dim, index) + + +@meta_patched_function.register(torch.squeeze) +def torch_squeeze(input, dim=None): + shape = list(input.shape) + if dim is not None: + if dim < 0: + dim = input.dim() + dim + if shape[dim] == 1: + shape.pop(dim) + else: + new_shape = [] + for dim_value in shape: + if dim_value == 1: + continue + new_shape.append(dim_value) + shape = new_shape + return torch.empty(shape, device="meta") + + +@meta_patched_function.register(torch.Tensor.squeeze) +def torch_tensor_squeeze(self, dim=None): + return torch_squeeze(self, dim) + + +@meta_patched_function.register(torch.unsqueeze) +def torch_unsqueeze(input, dim): + shape = list(input.shape) + if dim < 0: + dim = input.dim() + 1 + dim + shape.insert(dim, 1) + return torch.empty(shape, device="meta") + + +@meta_patched_function.register(torch.Tensor.unsqueeze) +def torch_tensor_unsqueeze(self, dim): + return torch_unsqueeze(self, dim) + + +@meta_patched_function.register(torch.cat) +def torch_cat(tensors, dim=None, axis=None, *, out=None): + if dim is None and axis is None: + dim = 0 + if dim is None and axis is not None: + dim = axis + if dim < 0: + dim = tensors[0].dim() + dim + shapes = [t.shape for t in tensors] + shape = list(shapes[0]) + concatenated_dim = sum(shape[dim] for shape in shapes) + final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:] + return torch.empty(final_shape, device="meta") + + +@meta_patched_function.register(torch.repeat_interleave) +def torch_repeat_interleave(input, repeats, dim=None, output_size=None): + assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \ + "Argument 'repeats' should be of type 'torch.Tensor' or 'int'" + + shape = list(input.shape) if dim is not None else [input.numel()] + dim = dim if dim is not None else 0 + dim = input.dim() + dim if dim < 0 else dim + + if isinstance(repeats, int): + shape[dim] = shape[dim] * repeats + elif isinstance(repeats, torch.Tensor): + shape[dim] = repeats.sum() + return torch.empty(shape, device="meta") + + +@meta_patched_function.register(torch.Tensor.repeat_interleave) +def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None): + return torch_repeat_interleave(self, repeats, dim, output_size) + + +@meta_patched_function.register(torch.roll) +def torch_roll(input, shifts, dims=None): + return torch.empty(input.shape, device='meta') + + +@meta_patched_function.register(torch.full) +def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + assert out is None, 'assigning result to out is not supported yet' + return torch.empty(size, device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad) + + +@meta_patched_function.register(torch.max) +def torch_max(input, dim=None, keepdim=False, *, out=None): + assert out is None, 'assigning value to out is not supported yet' + if dim is not None: + if isinstance(dim, int): + shape = list(input.shape) + shape.pop(dim) + if keepdim: + shape.insert(dim, 1) + return torch.empty(shape, device='meta', dtype=input.dtype), torch.empty(shape, + device='meta', + dtype=input.dtype) + elif isinstance(dim, torch.Tensor): + # when dim is a 0D or 1D tensor, it will maintain the same shape + num_dims = dim.dim() + if num_dims in [0, 1]: + return torch.empty_like(input, device='meta') + else: + raise ValueError(f"Expected dim to a 0D or 1D tensor but got {num_dims} dimensions") + else: + return torch.empty([], device='meta', dtype=input.dtype) + + +@meta_patched_function.register(torch.Tensor.cpu) +def torch_tensor_cpu(input): + return input.clone() + + +@meta_patched_function.register(torch.Tensor.cuda) +def torch_tensor_cuda(input, *args, **kwargs): + return input.clone() diff --git a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e28e52585fffc193473a7c8270c103919cc63e0d --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py @@ -0,0 +1,7 @@ +from .activation_function import * +from .convolution import * +from .embedding import * +from .linear import * +from .normalization import * +from .pooling import * +from .rnn import * \ No newline at end of file diff --git a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py new file mode 100644 index 0000000000000000000000000000000000000000..d03da6588c1cbf56403dccc5989f4a4987b7e2a9 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py @@ -0,0 +1,13 @@ +import torch + +from ...registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.ReLU) +@meta_patched_module.register(torch.nn.Sigmoid) +@meta_patched_module.register(torch.nn.GELU) +@meta_patched_module.register(torch.nn.Tanh) +@meta_patched_module.register(torch.nn.ReLU6) +@meta_patched_module.register(torch.nn.PReLU) +def torch_nn_non_linear_act(self, input): + return torch.empty(input.shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..cf9f3487aac9f31ff799e3245132c40d643de64f --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py @@ -0,0 +1,113 @@ +import math + +import torch + +from ...registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.Conv1d) +def torch_nn_conv1d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d + l_in = input.shape[-1] + c_out = self.out_channels + l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + result_shape = input.shape[:-2] + ( + c_out, + l_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.Conv2d) +def torch_nn_conv2d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv2d + h_in, w_in = input.shape[-2:] + c_out = self.out_channels + h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * + (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + result_shape = input.shape[:-3] + ( + c_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.Conv3d) +def torch_nn_conv3d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv3d + d_in, h_in, w_in = input.shape[-3:] + c_out = self.out_channels + d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * + (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * + (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1) + result_shape = input.shape[:-4] + ( + c_out, + d_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.ConvTranspose1d) +def torch_nn_convtranspose1d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html + l_in = input.shape[-1] + c_out = self.out_channels + l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * + (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + result_shape = input.shape[:-2] + ( + c_out, + l_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.ConvTranspose2d) +def torch_nn_convtranspose2d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html + h_in, w_in = input.shape[-2:] + c_out = self.out_channels + h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * + (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * + (self.kernel_size[1] - 1) + self.output_padding[1] + 1) + result_shape = input.shape[:-3] + ( + c_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.ConvTranspose3d) +def torch_nn_convtranspose3d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html + d_in, h_in, w_in = input.shape[-3:] + c_out = self.out_channels + d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * + (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * + (self.kernel_size[1] - 1) + self.output_padding[1] + 1) + w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] * + (self.kernel_size[2] - 1) + self.output_padding[2] + 1) + result_shape = input.shape[:-4] + ( + c_out, + d_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..999e33b17c1c7b442d2a6db73f957be4413f1fa1 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py @@ -0,0 +1,9 @@ +import torch + +from ...registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.Embedding) +def torch_nn_embedding(self, input): + result_shape = input.shape + (self.embedding_dim,) + return torch.empty(result_shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_module/linear.py b/colossalai/fx/tracer/meta_patch/patched_module/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..56f13bf97532e26770a0be7a4226ee69a2124ee5 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/linear.py @@ -0,0 +1,10 @@ +import torch + +from ...registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.Linear) +def torch_nn_linear(self, input): + last_dim = input.shape[-1] + assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch' + return torch.empty(input.shape[:-1] + (self.out_features,), device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..c21ff64cf3dec9baf357771fa0d15b341b413ac1 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py @@ -0,0 +1,31 @@ +import torch + +from ...registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.LayerNorm) +@meta_patched_module.register(torch.nn.GroupNorm) +@meta_patched_module.register(torch.nn.BatchNorm1d) +@meta_patched_module.register(torch.nn.BatchNorm2d) +@meta_patched_module.register(torch.nn.BatchNorm3d) +def torch_nn_normalize(self, input): + # check shape + if isinstance(self, torch.nn.BatchNorm1d): + assert input.dim() in [2, 3] + elif isinstance(self, torch.nn.BatchNorm2d): + assert input.dim() == 4 + elif isinstance(self, torch.nn.BatchNorm3d): + assert input.dim() == 5 + + # normalization maintain the same shape as the input + return input.clone() + + +try: + import apex + meta_patched_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize) + meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize) + meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize) + meta_patched_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize) +except (ImportError, AttributeError): + pass diff --git a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..7ce23fbf7ac9368f4ec8496b252494e779fb5015 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py @@ -0,0 +1,202 @@ +import math + +import torch + +from ...registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.AvgPool1d) +def torch_nn_avgpool1d(self, input): + num_dim = input.dim() + assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions' + + l_in = input.shape[-1] + + def _convert_int_to_list(item): + if isinstance(item, int): + return [item] * 1 + else: + return item + + padding = _convert_int_to_list(self.padding) + kernel_size = _convert_int_to_list(self.kernel_size) + stride = _convert_int_to_list(self.stride) + + l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1) + + result_shape = tuple(input.shape[:-1]) + (l_out,) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.AvgPool2d) +def torch_nn_avgpool2d(self, input): + num_dim = input.dim() + assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions' + + h_in, w_in = input.shape[-2:] + + def _convert_int_to_list(item): + if isinstance(item, int): + return [item] * 2 + else: + return item + + padding = _convert_int_to_list(self.padding) + kernel_size = _convert_int_to_list(self.kernel_size) + stride = _convert_int_to_list(self.stride) + + h_out = math.floor((h_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1) + w_out = math.floor((w_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1) + + result_shape = tuple(input.shape[:-2]) + ( + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.AvgPool3d) +def torch_nn_avgpool3d(self, input): + num_dim = input.dim() + assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions' + + d_in, h_in, w_in = input.shape[-3:] + + def _convert_int_to_list(item): + if isinstance(item, int): + return [item] * 3 + else: + return item + + padding = _convert_int_to_list(self.padding) + kernel_size = _convert_int_to_list(self.kernel_size) + stride = _convert_int_to_list(self.stride) + + d_out = math.floor((d_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1) + h_out = math.floor((h_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1) + w_out = math.floor((w_in + 2 * padding[2] - kernel_size[2]) / stride[2] + 1) + + result_shape = tuple(input.shape[:-3]) + ( + d_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.MaxPool1d) +def torch_nn_maxpool1d(self, input): + num_dim = input.dim() + assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions' + + l_in = input.shape[-1] + + def _convert_int_to_list(item): + if isinstance(item, int): + return [item] * 1 + else: + return item + + padding = _convert_int_to_list(self.padding) + dilation = _convert_int_to_list(self.dilation) + kernel_size = _convert_int_to_list(self.kernel_size) + stride = _convert_int_to_list(self.stride) + + l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) + + result_shape = tuple(input.shape[:-1]) + (l_out,) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.MaxPool2d) +def torch_nn_maxpool2d(self, input): + num_dim = input.dim() + assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions' + + h_in, w_in = input.shape[-2:] + + def _convert_int_to_list(item): + if isinstance(item, int): + return [item] * 2 + else: + return item + + padding = _convert_int_to_list(self.padding) + dilation = _convert_int_to_list(self.dilation) + kernel_size = _convert_int_to_list(self.kernel_size) + stride = _convert_int_to_list(self.stride) + + h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) + w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1) + + result_shape = tuple(input.shape[:-2]) + ( + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.MaxPool3d) +def torch_nn_maxpool3d(self, input): + num_dim = input.dim() + assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions' + + d_in, h_in, w_in = input.shape[-3:] + + def _convert_int_to_list(item): + if isinstance(item, int): + return [item] * 3 + else: + return item + + padding = _convert_int_to_list(self.padding) + dilation = _convert_int_to_list(self.dilation) + kernel_size = _convert_int_to_list(self.kernel_size) + stride = _convert_int_to_list(self.stride) + + d_out = math.floor((d_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) + h_out = math.floor((h_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1) + w_out = math.floor((w_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1) + + result_shape = tuple(input.shape[:-3]) + ( + d_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.AdaptiveAvgPool1d) +@meta_patched_module.register(torch.nn.AdaptiveMaxPool1d) +def torch_nn_adapative_pooling_1d(self, input): + assert input.dim() in [2, 3] + if isinstance(self.output_size, int): + output_size = (self.output_size,) + else: + output_size = self.output_size + result_shape = tuple(input.shape[:-1]) + output_size + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.AdaptiveAvgPool2d) +@meta_patched_module.register(torch.nn.AdaptiveMaxPool2d) +def torch_nn_adapative_pooling_2d(self, input): + assert input.dim() in [3, 4] + if isinstance(self.output_size, int): + output_size = (self.output_size,) * 2 + else: + output_size = self.output_size + result_shape = tuple(input.shape[:-2]) + output_size + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.AdaptiveAvgPool3d) +@meta_patched_module.register(torch.nn.AdaptiveMaxPool3d) +def torch_nn_adapative_pooling_3d(self, input): + assert input.dim() in [4, 5] + if isinstance(self.output_size, int): + output_size = (self.output_size,) * 3 + else: + output_size = self.output_size + result_shape = tuple(input.shape[:-3]) + output_size + return torch.empty(result_shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..ee15ca34162e83612eb179e0cff066d9f06faf36 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py @@ -0,0 +1,16 @@ +from typing import Optional + +import torch + +from ...registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.GRU) +@meta_patched_module.register(torch.nn.RNN) +def torch_nn_rnn(self, input, hx): + assert input.shape[ + -1] == self.input_size, f'Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch' + assert hx.shape[ + -1] == self.hidden_size, f'Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch' + d = 2 if self.bidirectional else 1 + return torch.empty(input.shape[:-1] + (self.hidden_size * d,), device="meta"), hx diff --git a/colossalai/fx/tracer/registry.py b/colossalai/fx/tracer/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..12fc6de73d4435dea8ec58fa50b93a6070fd6254 --- /dev/null +++ b/colossalai/fx/tracer/registry.py @@ -0,0 +1,28 @@ +class PatchRegistry: + + def __init__(self, name): + self.name = name + self.store = {} + + def register(self, source): + + def wrapper(func): + self.store[source] = func + return func + + return wrapper + + def get(self, source): + assert source in self.store + target = self.store[source] + return target + + def has(self, source): + return source in self.store + + +meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution') +meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution') +bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition') +bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition') +bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition') diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..bf6f9c23bf643f14fd433e024321e54b775e182f --- /dev/null +++ b/colossalai/fx/tracer/tracer.py @@ -0,0 +1,536 @@ +#!/usr/bin/env python +""" +tracer.py: + Implemented a tracer which supports control flow and user-defined meta arguments. + The implementation is partly inspired HuggingFace's fx tracer +""" +import enum +import functools +import inspect +import operator +from contextlib import contextmanager +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.fx import Node, Tracer +from torch.fx.graph import Graph, magic_methods, reflectable_magic_methods +from torch.fx.proxy import ParameterProxy, Proxy + +from ..proxy import ColoProxy +from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list +from .bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict +from .registry import ( + bias_addition_function, + bias_addition_method, + bias_addition_module, + meta_patched_function, + meta_patched_module, +) + +__all__ = ['ColoTracer'] + + +class TracerType(enum.Enum): + DEFAULT = 1 + META = 2 + + +class ColoTracer(Tracer): + """ + ColoTracer is a symbolic tracer designed to support dynamic control flow by using meta tensors for the `colossalai.fx` module. + This tracer is initialized in the same way as the original torch.fx.Tracer. + + Usage:: + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(10, 10) + self.linear2 = nn.Linear(10, 10) + + def forward(self, x, y): + x1 = self.linear1(x) + y1 = self.linear2(y) + + if x1.dim() == 2: + return x1 + y1 + else: + return x1 - y1 + + model = Model() + tracer = ColoTracer() + graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')}) + """ + + def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.tracer_type = TracerType.META + self.proxy_cls = ColoProxy + + # whether the tracer will record the usage of torch.utils.checkpoint + self.trace_act_ckpt = trace_act_ckpt + # whether the current tracing occurs within the activation checkpoint functions + self.inside_torch_checkpoint_func = False + self.act_ckpt_region_count = 0 + + # Feature flag for proxying accesses to buffer values + proxy_buffer_attributes: bool = True + + _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor", "finfo"] + + def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy: + """ + Create a proxy for different kinds of operations. + """ + + if self.tracer_type == TracerType.DEFAULT: + # since meta_args is not given + # we just fall back to the original torch.fx.Tracer + proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) + return proxy + + # if graph is traced for auto parallelism module, some extra node will be added during + # graph construction to deal with the compatability between bias addition and all reduce. + + # if no extra manipulation is applied, we just pass the origin arguments to create_proxy function + # to create node on computation graph + origin_arguments = (kind, target, args, kwargs, name, type_expr, proxy_factory_fn) + # dispatch the arguments generator depending on the kind and target in origin arguments. + args_metas, _ = extract_meta(*args, **kwargs) + handle = None + if kind == "call_function": + if bias_addition_function.has(target): + if target == torch.nn.functional.linear: + if 'bias' in kwargs and kwargs['bias'] is not None: + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute) + else: + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute) + elif bias_addition_function.has(target.__name__): + # use name for some builtin op like @ (matmul) + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs, function_to_substitute) + + elif kind == "call_method": + method = getattr(args_metas[0].__class__, target) + if bias_addition_method.has(method): + function_to_substitute = method_to_func_dict[method] + handle = bias_addition_method.get(method)(self, target, args, kwargs, function_to_substitute) + + elif kind == "call_module": + if not hasattr(self, "orig_forward"): + raise AttributeError(f"{self} does not have an attribute called orig_forward") + self._disable_module_getattr = True + try: + mod = self.root.get_submodule(target) + mod_type = type(mod) + if bias_addition_module.has(mod_type) and mod.bias is not None: + function_to_substitute = module_to_func_dict[mod_type] + handle = bias_addition_module.get(mod_type)(self, target, args, kwargs, function_to_substitute) + finally: + self._disable_module_getattr = False + + if handle is not None: + return handle.generate() + + # create nodes using patched arguments + proxy = super().create_proxy(*origin_arguments) + proxy: ColoProxy + meta_out = self._meta_data_computing( + kind, + target, + args, + kwargs, + ) + proxy.meta_data = meta_out + + return proxy + + def _module_getattr(self, attr, attr_val, parameter_proxy_cache): + if getattr(self, "_disable_module_getattr", False): + return attr_val + else: + # return super()._module_getattr(attr, attr_val, parameter_proxy_cache) + def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache): + for n, p in collection_to_search: + if attr_val is p: + if n not in parameter_proxy_cache: + kwargs = {} + if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters: + kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else + lambda node: ParameterProxy(self, node, n, attr_val)) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] + parameter_proxy_cache[n] = val_proxy + return parameter_proxy_cache[n] + return None + + if isinstance(attr_val, torch.nn.Parameter): + maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), + parameter_proxy_cache) + if maybe_parameter_proxy is not None: + return maybe_parameter_proxy + + if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): + maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), + parameter_proxy_cache) + if maybe_buffer_proxy is not None: + return maybe_buffer_proxy + + return attr_val + + def call_module(self, m, forward, args, kwargs): + self.orig_forward = forward + module_qualified_name = self.path_of_module(m) + + # a leaf module is the torch.nn.Module subclasses starting with `torch.nn` + # which means customized modules are not leaf module by default + # if a customized or third-party module like apex.normalization.FusedRMSNorm is patched, + # we should treat it as leaf module as well + if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name): + return self.create_proxy('call_module', module_qualified_name, args, kwargs) + else: + return forward(*args, **kwargs) + + def proxy(self, node) -> Proxy: + """ + Returns a ColoProxy object. + """ + return self.proxy_cls(node, self) + + def _configure_tracer_type(self, tracer_type: TracerType): + if tracer_type == TracerType.DEFAULT: + self.proxy_cls = Proxy + self.tracer_type = TracerType.DEFAULT + elif tracer_type == TracerType.META: + self.proxy_cls = ColoProxy + self.tracer_type = TracerType.META + else: + raise ValueError(f"Unrecognised tracer type {tracer_type}") + + def _meta_data_computing(self, kind, target, args, kwargs): + + if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta: + meta_out = self.meta_args[target] + return meta_out + + if target in self.orig_torch_tensor_methods: + # NOTE: tensor constructors in PyTorch define the `device` argument as + # *kwargs-only*. That is why this works. If you add methods to + # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only, + # this will break and you will likely see issues where we cannot infer + # the size of the output. + if "device" in kwargs: + kwargs["device"] = "meta" + + try: + args_metas, kwargs_metas = extract_meta(*args, **kwargs) + + if kind == "call_function": + # fetch patched function + if meta_patched_function.has(target): + meta_target = meta_patched_function.get(target) + elif meta_patched_function.has(target.__name__): + # use name for some builtin op like @ (matmul) + meta_target = meta_patched_function.get(target.__name__) + else: + meta_target = target + + meta_out = meta_target(*args_metas, **kwargs_metas) + if isinstance(meta_out, torch.Tensor): + meta_out = meta_out.to(device="meta") + elif kind == "call_method": + method = getattr(args_metas[0].__class__, target) + + # fetch patched method + if meta_patched_function.has(method): + meta_target = meta_patched_function.get(method) + else: + meta_target = method + + meta_out = meta_target(*args_metas, **kwargs_metas) + elif kind == "call_module": + if not hasattr(self, "orig_forward"): + raise AttributeError(f"{self} does not have an attribute called orig_forward") + self._disable_module_getattr = True + try: + mod = self.root.get_submodule(target) + mod_type = type(mod) + if meta_patched_module.has(mod_type): + meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas) + else: + meta_out = self.orig_forward(*args_metas, **kwargs_metas) + finally: + self._disable_module_getattr = False + elif kind == "get_attr": + self._disable_module_getattr = True + try: + attr_itr = self.root + atoms = target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + if isinstance(attr_itr, torch.nn.parameter.Parameter): + meta_out = torch.nn.Parameter(attr_itr.to(device="meta")) + elif isinstance(attr_itr, torch.Tensor): + meta_out = attr_itr.to(device="meta") + else: + meta_out = attr_itr + finally: + self._disable_module_getattr = False + else: + return None + + except Exception as e: + raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}") + + return meta_out + + def trace(self, + root: nn.Module, + concrete_args: Optional[Dict[str, Tensor]] = None, + meta_args: Optional[Dict[str, Tensor]] = None) -> Graph: + """ + Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow. + + Args: + root (nn.Module): a `nn.Module` object to trace the computation graph + meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph. + These arguments are the sample data fed to the model during actual computation, but just converted to meta tensors. + concrete_args (Optional[Dict[str, Tensor]]): the concrete arguments that should not be treated as Proxies. + """ + if meta_args is None: + meta_args = {} + + if concrete_args is None: + concrete_args = {} + + if len(meta_args) == 0: + self._configure_tracer_type(TracerType.DEFAULT) + else: + self._configure_tracer_type(TracerType.META) + + # check concrete and meta args have valid names + sig = inspect.signature(root.forward) + sig_names = set(sig.parameters.keys()) + meta_arg_names = set(meta_args.keys()) + + # update concrete args with default values + non_meta_arg_names = sig_names - meta_arg_names + for k, v in sig.parameters.items(): + if k in non_meta_arg_names and \ + k not in concrete_args and \ + v.default is not inspect.Parameter.empty: + concrete_args[k] = v.default + + # get non concrete arg names + concrete_arg_names = set(concrete_args.keys()) + non_concrete_arg_names = sig_names - concrete_arg_names + + def _check_arg_name_valid(names): + success, element = is_element_in_list(names, sig_names) + if not success: + raise KeyError( + f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function") + + _check_arg_name_valid(meta_arg_names) + _check_arg_name_valid(concrete_arg_names) + + # assign as attributed for late reference + def _check_kwargs(kwargs, should_be_meta: bool): + for k, v in kwargs.items(): + if not should_be_meta: + assert not torch.is_tensor(v) or not v.is_meta, \ + f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer' + else: + assert v.is_meta == should_be_meta, \ + f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer' + + _check_kwargs(concrete_args, should_be_meta=False) + _check_kwargs(meta_args, should_be_meta=True) + + self.concrete_args = concrete_args + self.meta_args = meta_args + + self.patched_torch_tensor_methods = {} + if self.tracer_type == TracerType.META: + # wrap the torch tensor constructing methods so that they are captured in the graph + self.patched_torch_tensor_methods = { + target: wrap_tensor_constructor_method(getattr(torch, target)) + for target in self._TORCH_METHODS_TO_PATCH + } + + # patch these methods to replace their original use + for name, (wrapper, orig) in self.patched_torch_tensor_methods.items(): + setattr(torch, name, wrapper) + + # cache these methods so that we can detect whether a method call + # should be patched during tracing + self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()] + + try: + # to track the usage of torch.utils.checkpoint + with self.trace_activation_checkpoint(enabled=self.trace_act_ckpt): + self.graph = super().trace(root, concrete_args=concrete_args) + + finally: + # recover the patched methods + for name, (_, orig) in self.patched_torch_tensor_methods.items(): + setattr(torch, name, orig) + + if self.tracer_type == TracerType.DEFAULT: + return self.graph + + # This is necessary because concrete args are added as input to the traced module since + # https://github.com/pytorch/pytorch/pull/55888. + for node in self.graph.nodes: + if node.op == "placeholder": + # Removing default values for inputs as the forward pass will fail with them. + if node.target in non_concrete_arg_names: + node.args = () + # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor]. + # It cannot infer on the attributes and methods the input should have, and fails. + node.type = torch.Tensor + # It is a concrete arg so it is not used and should be removed. + else: + if hasattr(torch.fx._symbolic_trace, "_assert_is_none"): + # Newer versions of torch.fx emit an assert statement + # for concrete arguments; delete those before we delete + # the concrete arg. + to_delete = [] + for user in node.users: + if user.target == torch.fx._symbolic_trace._assert_is_none: + to_delete.append(user) + for user in to_delete: + self.graph.erase_node(user) + + self.graph.erase_node(node) + + # TODO: solves GraphModule creation. + # Without this, return type annotation "Tuple" is causing code execution failure. + if node.op == "output": + node.type = None + + return self.graph + + @contextmanager + def trace_activation_checkpoint(self, enabled: bool): + if enabled: + orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction + + class PatchedCheckpointFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, run_function, preserve_rng_state, *args): + # signal that the current tracing occurs within activaton checkpoint part + self.inside_torch_checkpoint_func = True + out = run_function(*args) + self.inside_torch_checkpoint_func = False + self.act_ckpt_region_count += 1 + return out + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + raise NotImplementedError( + "We do not implement the backward pass as we only trace the forward pass.") + + # override the checkpoint function + torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction + yield + + if enabled: + # recover the checkpoint function upon exit + torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func + + def create_node(self, *args, **kwargs) -> Node: + node = super().create_node(*args, **kwargs) + + if self.inside_torch_checkpoint_func: + # annotate the activation checkpoint module + node.meta['activation_checkpoint'] = self.act_ckpt_region_count + return node + + +def wrap_tensor_constructor_method(target): + + def look_for_proxy(*args, **kwargs): + # find in pos vars + for arg in args: + if isinstance(arg, Proxy): + return arg + if isinstance(arg, (tuple, list)): + return look_for_proxy(*arg) + + # find in keyword vars + for k, v in kwargs.items(): + if isinstance(v, Proxy): + return v + if isinstance(v, (tuple, list)): + return look_for_proxy(*v) + return None + + @functools.wraps(target) + def wrapper(*args, **kwargs): + proxy = look_for_proxy(*args, **kwargs) + + if proxy is not None: + # if the arg is a proxy, then need to record this function called on this proxy + # e.g. torch.ones(size) where size is an input proxy + colo_proxy = proxy.tracer.create_proxy("call_function", target, args, kwargs) + if not isinstance(colo_proxy, ColoProxy): + meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs) + colo_proxy = ColoProxy(proxy.node) + colo_proxy.meta_data = meta_out + return colo_proxy + else: + # this is called directly when the inputs do not contain proxy + # e.g. torch.ones(4) where the input is static + return target(*args, **kwargs) + + return wrapper, target + + +# Patched magic methods for ColoProxy, then tracer could record the magic_method like __sub__, +# and add meta_data attribute to the created proxy. +for method in magic_methods: + + def _scope(method): + + def impl(*args, **kwargs): + + tracer = args[0].tracer + target = getattr(operator, method) + proxy = tracer.create_proxy('call_function', target, args, kwargs) + if not isinstance(proxy, ColoProxy): + meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs) + proxy = ColoProxy(proxy.node) + proxy.meta_data = meta_out + return proxy + + impl.__name__ = method + as_magic = f'__{method.strip("_")}__' + setattr(ColoProxy, as_magic, impl) + + _scope(method) + + +def _define_reflectable(orig_method_name): + method_name = f'__r{orig_method_name.strip("_")}__' + + def impl(self, rhs): + target = getattr(operator, orig_method_name) + proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {}) + if not isinstance(proxy, ColoProxy): + meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {}) + proxy = ColoProxy(proxy.node) + proxy.meta_data = meta_out + return proxy + + impl.__name__ = method_name + impl.__qualname__ = method_name + setattr(ColoProxy, method_name, impl) + + +for orig_method_name in reflectable_magic_methods: + _define_reflectable(orig_method_name) diff --git a/colossalai/gemini/__init__.py b/colossalai/gemini/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5a44ebb1ef2fb3e1df7c1217ab91a48c24cdae --- /dev/null +++ b/colossalai/gemini/__init__.py @@ -0,0 +1,9 @@ +from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration +from .gemini_mgr import GeminiManager +from .stateful_tensor_mgr import StatefulTensorMgr +from .tensor_placement_policy import TensorPlacementPolicyFactory + +__all__ = [ + 'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', + 'search_chunk_configuration' +] diff --git a/colossalai/gemini/chunk/__init__.py b/colossalai/gemini/chunk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6914d2dbef4581dbf37610cfc7589a2c5be77406 --- /dev/null +++ b/colossalai/gemini/chunk/__init__.py @@ -0,0 +1,6 @@ +from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState +from .manager import ChunkManager +from .search_utils import classify_params_by_dp_degree, search_chunk_configuration +from .utils import init_chunk_manager + +__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager'] diff --git a/colossalai/gemini/chunk/chunk.py b/colossalai/gemini/chunk/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b2741979c19bbc5a8938470d21e3ea98bf8f09 --- /dev/null +++ b/colossalai/gemini/chunk/chunk.py @@ -0,0 +1,576 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Optional + +import torch +import torch.distributed as dist + +from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.utils import get_current_device + + +class TensorState(Enum): + FREE = 0 + COMPUTE = 1 + HOLD = 2 + HOLD_AFTER_BWD = 3 + READY_FOR_REDUCE = 4 + + +STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), + (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), + (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), + (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), + (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, + TensorState.HOLD)) + + +@dataclass +class TensorInfo: + state: TensorState + offset: int + end: int + + +class ChunkFullError(Exception): + pass + + +def is_storage_empty(tensor: torch.Tensor) -> bool: + return tensor.storage().size() == 0 + + +def free_storage(tensor: torch.Tensor) -> None: + if not is_storage_empty(tensor): + tensor.storage().resize_(0) + + +def alloc_storage(tensor: torch.Tensor) -> None: + if is_storage_empty(tensor): + tensor.storage().resize_(tensor.numel()) + + +class Chunk: + _total_number = 0 + + def __init__(self, + chunk_size: int, + process_group: ColoProcessGroup, + dtype: torch.dtype, + init_device: Optional[torch.device] = None, + cpu_shard_init: bool = False, + keep_gathered: bool = False, + pin_memory: bool = False) -> None: + """ + Chunk: A container owning a piece of contiguous memory space for tensors + Here we use all-gather operation to gather the whole chunk. + Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters. + It is designed to make the full use of communication and PCIE bandwidth. + + Args: + chunk_size (int): the number of elements in the chunk + process_group (ColoProcessGroup): the process group of this chunk + dtype (torch.dtype): the data type of the chunk + init_device (torch.device): optional, During the chunk construction process, where the tensor is stored. + The default value is None, which is the current GPU + cpu_shard_init (bool): a flag indicates the local chunk shard is resident on CPU. + keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory + pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory + """ + self.count_id = Chunk._total_number + Chunk._total_number += 1 + + self.chunk_size = chunk_size + self.utilized_size = 0 + + self.torch_pg = process_group.dp_process_group() + self.pg_size = dist.get_world_size(self.torch_pg) + self.pg_rank = dist.get_rank(self.torch_pg) + + # the chunk size should be divisible by the dp degree + if not keep_gathered: + assert chunk_size % self.pg_size == 0 + self.shard_size = chunk_size // self.pg_size + self.shard_begin = self.shard_size * self.pg_rank + self.shard_end = self.shard_begin + self.shard_size + self.valid_end = self.shard_size + + self.dtype = dtype + device = init_device or get_current_device() + + # chunk_temp is a global chunk, which only exists during building the chunks. + self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero + + self.cuda_global_chunk = None # we force cuda_global_chunk located in CUDA + + # cuda local chunk, which is sharded on GPUs + self.cuda_shard = None + # cpu local chunk, which is sharded on CPUs + self.cpu_shard = None + # is the chunks gathers, which means chunks are duplicated on each process, + # and we should use the cuda_global_chunk. + self.is_gathered = True + + # configure the init device of the shard + # no-offload default: fp16, fp32 -> CUDA + # offload default: fp16, fp32 -> CPU + self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device() + + self.chunk_mem = self.chunk_size * self.chunk_temp.element_size() + self.shard_mem = self.chunk_mem // self.pg_size + + # each tensor is associated with a TensorInfo to track its meta info + # (state, offset, end) + self.tensors_info: Dict[torch.Tensor, TensorInfo] = {} + # the total number of tensors in the chunk + self.num_tensors = 0 + + # Record the number of tensors in different states + self.tensor_state_cnter: Dict[TensorState, int] = dict() + for state in TensorState: + self.tensor_state_cnter[state] = 0 + + # If a chunk is kept gathered, + # they are treated the same as that of the parameters in DDP during training. + self.keep_gathered = keep_gathered + if self.keep_gathered: + pin_memory = False # since this chunk is gathered, it doesn't need to pin + + # if pin_memory is True, we allocate a piece of CPU pin-memory + # for it all the time + self.pin_memory = pin_memory + + # we introduce the paired chunk here + # it refers to another chunk having the same parameters + # but with different dtype(such as fp16_chunk.paired_chunk -> fp32_chunk + self.paired_chunk = None + # if this chunk is synchronized with the optimizer, the flag is True + self.optim_sync_flag = True + # if the cpu_shard has been visited during the training step, the flag is True + self.cpu_vis_flag = False + + # whether to record l2 norm for the gradient clipping calculation + self.l2_norm_flag = False + self.l2_norm = None + + @property + def memory_usage(self) -> Dict[str, int]: + cuda_memory = 0 + cpu_memory = 0 + + if self.chunk_temp is not None: + # this chunk is not closed + if self.chunk_temp.device.type == 'cuda': + cuda_memory += self.chunk_mem + else: + cpu_memory += self.chunk_mem + else: + if self.is_gathered: + cuda_memory += self.chunk_mem + if self.cuda_shard is not None: + cuda_memory += self.shard_mem + if self.cpu_shard is not None: + cpu_memory += self.shard_mem + + return dict(cuda=cuda_memory, cpu=cpu_memory) + + @property + def device_type(self) -> str: + if self.chunk_temp is not None: + return self.chunk_temp.device.type + else: + if self.is_gathered: + return 'cuda' + elif self.cuda_shard is not None: + return 'cuda' + else: + return 'cpu' + + @property + def payload(self) -> torch.Tensor: + # sanity check + assert self.chunk_temp is None + + if self.is_gathered: + return self.cuda_global_chunk + elif self.cuda_shard is not None: + return self.cuda_shard + else: + return self.cpu_shard + + @property + def payload_mem(self) -> int: + # sanity check + assert self.chunk_temp is None + + if self.is_gathered: + return self.chunk_mem + else: + return self.shard_mem + + @property + def can_move(self) -> bool: + return not self.is_gathered + + @property + def can_release(self) -> bool: + if self.keep_gathered: + return False + else: + return self.tensor_state_cnter[TensorState.HOLD] + \ + self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors + + @property + def can_reduce(self): + return self.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == self.num_tensors + + @property + def has_inf_or_nan(self) -> bool: + """Check if the chunk has inf or nan values on CUDA. + """ + if self.is_gathered: + valid_tensor = self.cuda_global_chunk[:self.utilized_size] + else: + assert self.cuda_shard is not None # only check on CUDA + valid_tensor = self.cuda_shard[:self.valid_end] + + return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() + + def set_l2_norm(self) -> None: + """Record l2 norm of this chunks on CUDA. + """ + assert self.l2_norm is None, "you are calculating the l2 norm twice" + if self.is_gathered: + valid_tensor = self.cuda_global_chunk[:self.utilized_size] + else: + assert self.cuda_shard is not None # calculate on CUDA + valid_tensor = self.cuda_shard[:self.valid_end] + chunk_l2_norm = valid_tensor.data.float().norm(2) + self.l2_norm = chunk_l2_norm.item()**2 + + def append_tensor(self, tensor: torch.Tensor): + """Add a tensor to the chunk. + + Args: + tensor (torch.Tensor): a tensor to be added to the chunk + """ + # sanity check + assert self.chunk_temp is not None + assert tensor.dtype == self.dtype + + new_utilized_size = self.utilized_size + tensor.numel() + # raise exception when the chunk size is exceeded + if new_utilized_size > self.chunk_size: + raise ChunkFullError + + self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten()) + assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor" + tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape) + + # record all the information about the tensor + self.num_tensors += 1 + tensor_state = TensorState.HOLD + self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size) + self.tensor_state_cnter[tensor_state] += 1 + self.utilized_size = new_utilized_size + + def close_chunk(self): + """Close the chunk. Any tensor can't be appended to a closed chunk later. + """ + # sanity check + assert self.chunk_temp is not None + + # calculate the valid end for each shard + if self.utilized_size <= self.shard_begin: + self.valid_end = 0 + elif self.utilized_size < self.shard_end: + self.valid_end = self.utilized_size - self.shard_begin + + if self.chunk_temp.device.type == 'cpu': + self.cuda_global_chunk = self.chunk_temp.to(get_current_device()) + self.__update_tensors_ptr() + else: + self.cuda_global_chunk = self.chunk_temp + self.chunk_temp = None + + self.__scatter() + # gathered chunk never have shard attribute + if self.keep_gathered: + return + + if self.pin_memory or self.shard_device.type == 'cpu': + self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory) + self.cpu_shard.copy_(self.cuda_shard) + self.cpu_vis_flag = True # cpu_shard has been visited + + if self.shard_device.type == 'cpu': + self.cuda_shard = None + + def shard_move(self, device: torch.device, force_copy: bool = False): + """Move the shard tensor in the chunk. + + Args: + device: the device to which the shard will move + force_copy: if True, copy function is called mandatorily + """ + # sanity check + assert not self.is_gathered + # when the current chunk is not synchronized with the optimizer + # just use another way for the movement + if not self.optim_sync_flag: + assert device.type == 'cuda', "each chunk should first be moved to CUDA" + self.__paired_shard_move() + self.optim_sync_flag = True + return + + if device.type == 'cuda': + assert device == get_current_device(), "can't move chunk to another device" + + if self.cuda_shard: + return + + self.cuda_shard = self.cpu_shard.to(get_current_device()) + + if not self.pin_memory: + self.cpu_shard = None + elif device.type == 'cpu': + if self.cuda_shard is None: + return + + if self.pin_memory: + if force_copy or not self.cpu_vis_flag: + self.cpu_shard.copy_(self.cuda_shard) + # if cpu_shard has been visited + # copy operation is not need + else: + self.cpu_shard = self.cuda_shard.cpu() + self.cpu_vis_flag = True + self.cuda_shard = None + else: + raise NotImplementedError + + def access_chunk(self): + """Make the chunk usable for the parameters inside it. It's an operation done in CUDA. + """ + # sanity check + assert self.chunk_temp is None + + if not self.is_gathered: + self.__gather() + self.__update_tensors_ptr() + + def release_chunk(self): + """Release the usable chunk. It's an operation done in CUDA. + """ + # sanity check + assert self.chunk_temp is None + + if self.is_gathered: + self.__scatter() + + def reduce(self): + """Reduce scatter all the gradients. It's an operation done in CUDA. + """ + # sanity check + assert self.is_gathered + + if self.pg_size == 1: + # tricky code here + # just move cuda_global_chunk to cuda_shard + # the communication is not necessary + self.__scatter() + elif self.keep_gathered: + # we use all-reduce here + dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg) + else: + self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device()) + + input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0)) + dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg) + + free_storage(self.cuda_global_chunk) + self.is_gathered = False + self.__update_tensors_state(TensorState.HOLD) + + def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: + """ + Make a transition of the tensor into the next state. + + Args: + tensor (torch.Tensor): a torch Tensor object. + tensor_state (TensorState): the target state for transition. + """ + + # As the gradient hook can be triggered either before or after post-backward + # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce + # or compute -> ready_for_reduce -> hold_after_bwd + # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd + # this function only apply valid state transformation + # invalid calls will be ignored and nothing changes + if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS: + return + self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state) + + def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: + """ + Copy data slice to the memory space indexed by the input tensor in the chunk. + + Args: + tensor (torch.Tensor): the tensor used to retrive meta information + data_slice (torch.Tensor): the tensor to be copied to the chunk + """ + # sanity check + assert self.is_gathered + + tensor_info = self.tensors_info[tensor] + self.cuda_global_chunk[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten()) + tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape) + + def get_valid_length(self) -> int: + """Get the valid length of the chunk's payload. + """ + if self.keep_gathered: + return self.utilized_size + else: + return self.valid_end + + def init_pair(self, friend_chunk: 'Chunk') -> None: + """Initialize the paired chunk. + """ + if self.paired_chunk is None and friend_chunk.paired_chunk is None: + self.paired_chunk = friend_chunk + friend_chunk.paired_chunk = self + else: + assert self.paired_chunk is friend_chunk + assert friend_chunk.paired_chunk is self + + def optim_update(self) -> None: + """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer. + """ + # sanity check + assert self.paired_chunk is not None + + friend_chunk = self.paired_chunk + if self.is_gathered is True: + assert friend_chunk.is_gathered is True + self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk) + self.optim_sync_flag = True + elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda': + self.cuda_shard.copy_(friend_chunk.cuda_shard) + self.optim_sync_flag = True + self.cpu_vis_flag = False + else: + # optim_sync_flag is set to False + # see shard_move function for more details + assert friend_chunk.device_type == 'cpu' + assert self.device_type == 'cpu' + self.optim_sync_flag = False + self.cpu_vis_flag = False + + def get_tensors(self) -> List[torch.Tensor]: + return list(self.tensors_info.keys()) + + def __gather(self): + if not self.is_gathered: + # sanity check + assert self.cuda_shard is not None + + alloc_storage(self.cuda_global_chunk) + gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0)) + dist.all_gather(gather_list, self.cuda_shard, self.torch_pg) + + self.cuda_shard = None + self.is_gathered = True + + def __scatter(self): + if self.keep_gathered: + return + + if self.is_gathered: + # sanity check + assert self.cuda_shard is None + + self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.cuda_global_chunk.device) + + self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin:self.shard_end]) + + free_storage(self.cuda_global_chunk) + self.is_gathered = False + + def __paired_shard_move(self): + assert self.paired_chunk is not None, "chunks should be paired before training" + optim_chunk = self.paired_chunk + assert self.chunk_size == optim_chunk.chunk_size + + # only be called when optimizer state is in CPU memory + # the grad and param should be in the same device + assert self.cuda_shard is None + temp = optim_chunk.cpu_shard.to(get_current_device()) + # avoid to transform FP32 in CPU + self.cuda_shard = temp.to(self.dtype) + + if not self.pin_memory: + self.cpu_shard = None + + def __update_tensors_ptr(self) -> None: + # sanity check + assert self.is_gathered + assert type(self.cuda_global_chunk) == torch.Tensor + + for tensor, tensor_info in self.tensors_info.items(): + tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape) + + def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState): + self.tensor_state_cnter[tensor_info.state] -= 1 + tensor_info.state = next_state + self.tensor_state_cnter[tensor_info.state] += 1 + + def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None): + for tensor_info in self.tensors_info.values(): + if prev_state is None or tensor_info.state == prev_state: + self.__update_one_tensor_info(tensor_info, next_state) + + def __hash__(self) -> int: + return hash(id(self)) + + def __eq__(self, __o: object) -> bool: + return self is __o + + def __repr__(self, detailed: bool = True): + output = [ + "Chunk Information:\n", + "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype, + self.pg_size), + "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format( + self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size) + ] + + def print_tensor(tensor, prefix=''): + output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, + tensor.device)) + + if self.chunk_temp is not None: + output.append("\tchunk temp:\n") + print_tensor(tensor=self.chunk_temp, prefix='\t\t') + + if self.cuda_global_chunk is not None and self.cuda_global_chunk.storage().size() > 0: + output.append("\tchunk total:\n") + print_tensor(tensor=self.cuda_global_chunk, prefix='\t\t') + + if self.cuda_shard is not None: + output.append("\tcuda shard:\n") + print_tensor(tensor=self.cuda_shard, prefix='\t\t') + + if self.cpu_shard is not None: + output.append("\tcpu shard:\n") + print_tensor(tensor=self.cpu_shard, prefix='\t\t') + + memory_info = self.memory_usage + output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu'])) + + if detailed: + output.append("\ttensor state monitor:\n") + for st in TensorState: + output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st])) + + return ''.join(output) diff --git a/colossalai/gemini/chunk/manager.py b/colossalai/gemini/chunk/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..07fb6c48b2d7f688f71ee94e01ee226b902a1d5f --- /dev/null +++ b/colossalai/gemini/chunk/manager.py @@ -0,0 +1,239 @@ +from collections import deque +from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple + +import torch + +from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState +from colossalai.tensor import ColoTensor +from colossalai.utils import get_current_device + + +class ChunkManager: + """ + A manager class to manipulate the tensors in chunks. + + Args: + chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager. + init_device (torch.device): optional, the device on which the chunk is initialized. The default is None. + """ + + def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None: + + self.device = init_device or get_current_device() + self.dp_degree_chunk_size_dict: Dict[int, int] = dict() + self.kwargs_config = chunk_configuration + for k, v in self.kwargs_config.items(): + self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size') + v['init_device'] = self.device + + self.chunk_groups: Dict[str, Deque] = dict() + self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict() + self.accessed_chunks: Set[Chunk] = set() + self.accessed_mem: int = 0 + self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} + + def register_tensor(self, + tensor: ColoTensor, + group_type: str, + config_key: int, + cpu_offload: bool = False, + pin_memory: bool = False) -> None: + """ + Register a tensor to the chunk manager. + Then, the tensor should be accessed by `get_chunks`. + + Args: + tensor: the tensor appended to the chunk + group_type: the data type of the group. + config_key: the key of the group's name, the size of the dp world + cpu_offload: if True, the chunk will be closed on CPU + pin_memory: whether the chunk is pinned in the cpu memory + """ + assert tensor not in self.tensor_chunk_map + assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager" + assert config_key in self.dp_degree_chunk_size_dict + + chunk_size = self.dp_degree_chunk_size_dict[config_key] + chunk_kwargs = self.kwargs_config[config_key] + group_name = "{}_{}".format(group_type, config_key) + chunk_group = self.__get_chunk_group(group_name) + + try: + # append the tensor to the last chunk + chunk_group[-1].append_tensor(tensor) + except (IndexError, ChunkFullError): + # the except statement will be triggered when there is no chunk or + # the last chunk in the chunk group is full + # this will create a new chunk and allocate this chunk to its corresponding process + if chunk_group: + # the chunk group is not empty + # close the last chunk + self.__close_one_chunk(chunk_group[-1]) + + if tensor.numel() > chunk_size: + chunk_size = tensor.numel() + chunk = Chunk( + chunk_size=chunk_size, + process_group=tensor.process_group, + dtype=tensor.dtype, + cpu_shard_init=cpu_offload, + pin_memory=pin_memory, + **chunk_kwargs, + ) + + chunk_group.append(chunk) + chunk.append_tensor(tensor) + self.__add_memory_usage(chunk.memory_usage) + + self.tensor_chunk_map[tensor] = chunk_group[-1] + + def close_all_groups(self): + """Close all the chunks of all groups. + """ + for group_name in self.chunk_groups: + self.__close_one_chunk(self.chunk_groups[group_name][-1]) + + def access_chunk(self, chunk: Chunk) -> None: + """Make the chunk can be used for calculation. + """ + if chunk in self.accessed_chunks: + return + self.__sub_memroy_usage(chunk.memory_usage) + if chunk.device_type == 'cpu': + chunk.shard_move(get_current_device()) + self.__add_accessed_chunk(chunk) + self.__add_memory_usage(chunk.memory_usage) + + def release_chunk(self, chunk: Chunk) -> None: + """Scatter the chunk in CUDA. + """ + if chunk not in self.accessed_chunks: + return + if chunk.can_release: + self.__sub_memroy_usage(chunk.memory_usage) + self.__sub_accessed_chunk(chunk) + self.__add_memory_usage(chunk.memory_usage) + + def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None: + """Move the shard of the chunk to the target device. + """ + if not chunk.can_move or chunk.device_type == device.type: + return + self.__sub_memroy_usage(chunk.memory_usage) + chunk.shard_move(device, force_copy) + self.__add_memory_usage(chunk.memory_usage) + + def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: + """Transit tensor state according to pre-defined state machine. + """ + chunk = self.tensor_chunk_map[tensor] + chunk.tensor_trans_state(tensor, state) + + def reduce_chunk(self, chunk: Chunk) -> bool: + """Reduce or all reduce the chunk. + """ + if not chunk.can_reduce: + return False + self.__sub_memroy_usage(chunk.memory_usage) + chunk.reduce() + self.__sub_accessed_chunk(chunk) + self.__add_memory_usage(chunk.memory_usage) + return True + + def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None: + """ + Copy data to the chunk. + + Args: + tensor (torch.Tensor): the tensor used to retrive meta information + data (torch.Tensor): the tensor to be copied to the chunk + """ + chunk = self.tensor_chunk_map[tensor] + chunk.copy_tensor_to_chunk_slice(tensor, data) + + def get_chunk(self, tensor: torch.Tensor) -> Chunk: + """ + Return the chunk owning the tensor. + + Args: + tensor (torch.Tensor): a torch tensor object + """ + return self.tensor_chunk_map[tensor] + + def get_cuda_movable_chunks(self) -> List[Chunk]: + """ + Get all chunks that can be moved. + """ + chunk_list = [] + for chunk in self.accessed_chunks: + if chunk.can_release: + chunk_list.append(chunk) + chunk_list.sort(key=lambda x: x.count_id) + return chunk_list + + def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]: + """ + Get all chunks owning the input tensors. + + Args: + tensors (Iterable[torch.Tensor]): the tensors used to look for chunks + """ + chunks = [] + for tensor in tensors: + chunk = self.get_chunk(tensor) + if chunk not in chunks: + chunks.append(chunk) + return tuple(chunks) + + def add_extern_static_tensor(self, tensor: torch.Tensor) -> None: + """Add extern static tensor to chunk manager. + Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them. + They are "static", which means their shape, dtype, device never change. + Thus, their memory usage never changes. + + Args: + tensor (torch.Tensor): An extern static tensor. E.g. optimizer state. + """ + assert tensor not in self.tensor_chunk_map + self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size() + + def __repr__(self) -> str: + msg = [ + 'Chunk Manager Information:\n', + 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n' + ] + for group_name, group in self.chunk_groups.items(): + msg.append(f'Group {group_name}:\n') + for i, chunk in enumerate(group): + msg.append(f'[{i}] {chunk}\n') + return ''.join(msg) + + def __get_chunk_group(self, group_name: str) -> Deque: + """Register a chunk group. + """ + if group_name not in self.chunk_groups: + self.chunk_groups[group_name] = deque() + return self.chunk_groups[group_name] + + def __close_one_chunk(self, chunk: Chunk): + self.__sub_memroy_usage(chunk.memory_usage) + chunk.close_chunk() + self.__add_memory_usage(chunk.memory_usage) + + def __sub_memroy_usage(self, usage: Dict[str, int]): + for k, v in usage.items(): + self.total_mem[k] -= v + + def __add_memory_usage(self, usage: Dict[str, int]): + for k, v in usage.items(): + self.total_mem[k] += v + + def __add_accessed_chunk(self, chunk: Chunk): + chunk.access_chunk() + self.accessed_chunks.add(chunk) + self.accessed_mem += chunk.chunk_mem + + def __sub_accessed_chunk(self, chunk: Chunk): + chunk.release_chunk() + self.accessed_chunks.remove(chunk) + self.accessed_mem -= chunk.chunk_mem diff --git a/colossalai/gemini/chunk/search_utils.py b/colossalai/gemini/chunk/search_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..312d77f1826c2ab0755a0a1d706cfb602f521d81 --- /dev/null +++ b/colossalai/gemini/chunk/search_utils.py @@ -0,0 +1,140 @@ +import math +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch.nn as nn + +from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator +from colossalai.tensor import ColoParameter + + +def in_ddp(param: nn.Parameter) -> bool: + return not getattr(param, '_ddp_to_ignore', False) + + +def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None: + """ + Filter those parameters whose size is too large (more than 3x standard deviations) from others. + """ + params_size = [p.numel() for p in model.parameters() if in_ddp(p)] + params_size_arr = np.array(params_size) + + std = np.std(params_size_arr) + mean = np.mean(params_size_arr) + upper_limit = mean + 3 * std + + for key in size_dict: + org_list = size_dict[key] + size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list)) + + +def _get_unused_byte(size_list: List[int], chunk_size: int) -> int: + """Get unused byte for a certain chunk size. + """ + acc = 0 + left = 0 + for s in size_list: + if s > left: + acc += left + left = chunk_size + left -= s + return left + acc + + +def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int, List[ColoParameter]]: + """classify_params_by_dp_degree + + Classify the parameters by their dp degree + + Args: + param_order (OrderedParamGenerator): the order of param be visied + + Returns: + Dict[int, List[ColoParameter]]: a dict contains the classification results. + The keys are dp_degrees and the values are parameters. + """ + params_dict: Dict[int, List[ColoParameter]] = dict() + for param in param_order.generate(): + assert isinstance(param, ColoParameter), "please init model in the ColoInitContext" + if not in_ddp(param): + continue + + param_key = param.process_group.dp_world_size() + + if param_key not in params_dict: + params_dict[param_key] = [] + params_dict[param_key].append(param) + + return params_dict + + +def search_chunk_configuration( + model: nn.Module, + search_range_mb: float, + search_interval_byte: int, # hidden size is the best value for the interval + min_chunk_size_mb: float = 32, + filter_exlarge_params: bool = True, + memstas: Optional[MemStats] = None) -> Tuple[Dict, int]: + """search_chunk_configuration + + Args: + model (nn.Module): torch module + search_range_mb (float): searching range in mega byte. + search_interval_byte (int): searching interval in byte. + filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True. + + Returns: + Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte. + """ + + if memstas is not None: + param_order = memstas.param_order() + else: + # build the param visited order right now + param_order = OrderedParamGenerator() + for p in model.parameters(): + param_order.append(p) + + search_range_byte = round(search_range_mb * 1024**2) + min_chunk_size_byte = round(min_chunk_size_mb * 1024**2) + assert search_range_byte >= 0 + + params_dict = classify_params_by_dp_degree(param_order) + config_dict: Dict[int, Dict] = dict() + + size_dict: Dict[int, List[int]] = dict() + for dp_degree in params_dict: + params_list = params_dict[dp_degree] + size_list = [p.numel() for p in params_list] + # let small parameters keep gathered in CUDA all the time + total_size = sum(size_list) + if total_size < min_chunk_size_byte: + config_dict[dp_degree] = dict(chunk_size=total_size, keep_gathered=True) + else: + size_dict[dp_degree] = size_list + + if filter_exlarge_params: + _filter_exlarge_params(model, size_dict) + + max_size = min_chunk_size_byte + for key in size_dict: + max_size = max(max_size, max(size_dict[key])) + start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte) + + min_chunk_waste = float('+inf') + best_chunk_size = start_size + + for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte): + temp_waste = 0 + for key in size_dict: + temp_waste += _get_unused_byte(size_dict[key], chunk_size) + if temp_waste < min_chunk_waste: + min_chunk_waste = temp_waste + best_chunk_size = chunk_size + + for dp_degree in params_dict: + if dp_degree in config_dict: + continue + config_dict[dp_degree] = dict(chunk_size=best_chunk_size, keep_gathered=False) + + return config_dict, min_chunk_waste diff --git a/colossalai/gemini/chunk/utils.py b/colossalai/gemini/chunk/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e9a9f84e7a93fb8de2b71f4871f53c0e835882d0 --- /dev/null +++ b/colossalai/gemini/chunk/utils.py @@ -0,0 +1,59 @@ +from time import time +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn + +from colossalai.gemini.chunk import ChunkManager +from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration +from colossalai.gemini.memory_tracer import MemStats + + +def init_chunk_manager(model: nn.Module, + init_device: Optional[torch.device] = None, + hidden_dim: Optional[int] = None, + search_range_mb: Optional[float] = None, + min_chunk_size_mb: Optional[float] = None, + filter_exlarge_params: Optional[bool] = None) -> ChunkManager: + + kwargs_dict = dict() + + if hidden_dim: + search_interval_byte = hidden_dim + else: + search_interval_byte = 1024 # 1kb + kwargs_dict["search_interval_byte"] = search_interval_byte + + if search_range_mb: + kwargs_dict["search_range_mb"] = search_range_mb + + if min_chunk_size_mb: + kwargs_dict["min_chunk_size_mb"] = min_chunk_size_mb + + if filter_exlarge_params: + kwargs_dict["filter_exlarge_params"] = filter_exlarge_params + + params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)] + total_size = sum(params_sizes) / 1024**2 + + dist.barrier() + begin = time() + + config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict) + + dist.barrier() + end = time() + span_s = end - begin + wasted_size /= 1024**2 + + if dist.get_rank() == 0: + print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s), + "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size), + "total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)), + sep='', + flush=True) + dist.barrier() + + chunk_manager = ChunkManager(config_dict, init_device) + return chunk_manager diff --git a/colossalai/gemini/gemini_context.py b/colossalai/gemini/gemini_context.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7da6b80fbaddc43074d3599bdd0fd18548f94b --- /dev/null +++ b/colossalai/gemini/gemini_context.py @@ -0,0 +1,48 @@ +from enum import EnumMeta + + +class GeminiMemoryManager(object): + + def __init__(self, states_cls: EnumMeta): + super().__init__() + self.states_cls = states_cls + self._cnter = 0 # the counter of instances + + self.total_mem = dict() + self.state_mem = dict() + self.state_mem['cpu'] = dict() + self.state_mem['cuda'] = dict() + + self.reset() + + @property + def total_number(self): + return self._cnter + + def reset(self): + self._cnter = 0 # the counter of instances + + self.total_mem['cpu'] = 0 # memory occupation of instances in cpu + self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda + + # memory conditions for all states + for state in self.states_cls: + self.state_mem['cpu'][state] = 0 + self.state_mem['cuda'][state] = 0 + + def register_new_instance(self): + self._cnter += 1 + + def delete_instance(self): + self._cnter -= 1 + + def print_info(self): + print(f"Total number: {self.total_number}", + f"Total CPU memory occupation: {self.total_mem['cpu']}", + f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", + sep='\n') + + for state in self.states_cls: + print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}", + f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", + sep='\n') diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py new file mode 100644 index 0000000000000000000000000000000000000000..541762a72acf7678388db10bfb342db4ede15371 --- /dev/null +++ b/colossalai/gemini/gemini_mgr.py @@ -0,0 +1,156 @@ +import functools +from time import time +from typing import List, Optional, Tuple + +import torch + +from colossalai.gemini.chunk import Chunk, ChunkManager +from colossalai.gemini.memory_tracer import MemStats + +from .memory_tracer import ChunkMemStatsCollector +from .placement_policy import PlacementPolicyFactory + + +class GeminiManager: + """ + Stateful Tensor Manager, inspired from PatrickStar + + PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management + https://arxiv.org/abs/2108.05818 + + Args: + placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'. + If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used. + If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used. + If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well. + Note that 'auto' policy can only work well when no other processes use CUDA during your training. + chunk_manager (ChunkManager): A ``ChunkManager`` instance. + memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration. + """ + + def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None: + + assert placement_policy in PlacementPolicyFactory.get_polocy_names() + self.policy_name = placement_policy + policy_cls = PlacementPolicyFactory.create(placement_policy) + self._chunk_manager = chunk_manager + + self._premade_memstats_ = memstats is not None + self._memstats = memstats + self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager, + self._memstats) if policy_cls.need_mem_stats else None + self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector) + self._compute_list: List[Tuple[Chunk, ...]] = [] + self._compute_idx: int = -1 + + self._h2d_volume = 0 + self._d2h_volume = 0 + self._layout_time = 0 + self._evict_time = 0 + self._warmup = True + self._comp_cuda_demand_time = 0 + + def memstats(self): + """memstats + + get the memory statistics during training. + The stats could be collected by a runtime memory tracer, or collected by the GeminiManager. + Note, for the latter, you can not access the memstats before warmup iteration finishes. + """ + if self._premade_memstats_: + return self._memstats + else: + assert not self._warmup, "Gemini Manager has memstats after warm up! Now is during warmup." + return self._mem_stats_collector._memstats + + def pre_iter(self, *args): + if self._mem_stats_collector and self._warmup: + self._mem_stats_collector.start_collection() + + def post_iter(self): + """This function must be called when each iteration finishes + """ + if self._mem_stats_collector and self._warmup: + self._mem_stats_collector.finish_collection() + self._warmup = False + self._compute_idx = -1 + self._h2d_volume = 0 + self._d2h_volume = 0 + self._layout_time = 0 + self._evict_time = 0 + self._comp_cuda_demand_time = 0 + + def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None: + """ Adjust the layout of stateful tensors according to the information provided + by mem_stats_collector, which should belongs to a Sharded Model. + """ + # find stateful tensor in state COMPUTE + start = time() + self._record_chunks_order(chunks) + cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks) + self._layout_time += time() - start + + vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list, + cuda_demand=cuda_demand, + warmup=self._warmup, + compute_list=self._compute_list, + compute_idx=self._compute_idx) + + self._d2h_volume += vol + self._evict_time += evict_time + # move COMPUTE tensors to CUDA + self._h2d_volume += cuda_demand + + @functools.lru_cache(maxsize=None) + def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]): + start = time() + cuda_demand = 0 + for chunk in chunks: + if chunk.device_type == 'cuda': + if chunk.is_gathered: + pass + else: + cuda_demand += chunk.chunk_mem - chunk.shard_mem + elif chunk.device_type == 'cpu': + cuda_demand += chunk.chunk_mem + else: + raise RuntimeError + self._comp_cuda_demand_time += time() - start + + can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks() + return cuda_demand, can_evict_chunks + + def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None: + self._compute_idx += 1 + if self._warmup and self._placement_policy.need_mem_stats: + self._compute_list.append(chunks) + + @property + def default_device(self): + return self._placement_policy.get_default_device() + + def sample_overall_data(self): + if self._mem_stats_collector: + self._mem_stats_collector.sample_overall_data() + + def record_model_data_volume(self): + if self._mem_stats_collector: + self._mem_stats_collector.record_model_data_volume() + + @property + def chunk_manager(self): + return self._chunk_manager + + @property + def cuda_margin_mem(self) -> Optional[float]: + if self._mem_stats_collector: + return self._mem_stats_collector.cuda_margin_mem + return None + + @property + def is_cuda_margin_mem_avail(self) -> bool: + return self._placement_policy.need_mem_stats + + @staticmethod + def get_default_device(policy_name: str) -> torch.device: + return PlacementPolicyFactory.get_default_device(policy_name) diff --git a/colossalai/gemini/memory_tracer/__init__.py b/colossalai/gemini/memory_tracer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02c9d5754ec9a34c11531111a8fd6ca5e6698c96 --- /dev/null +++ b/colossalai/gemini/memory_tracer/__init__.py @@ -0,0 +1,11 @@ +from .param_runtime_order import OrderedParamGenerator # isort:skip +from .memory_stats import MemStats # isort:skip +from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip +from .memstats_collector import MemStatsCollector # isort:skip +from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip +from .static_memstats_collector import StaticMemStatsCollector # isort:skip + +__all__ = [ + 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', + 'StaticMemStatsCollector', 'MemStats', 'OrderedParamGenerator' +] diff --git a/colossalai/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/gemini/memory_tracer/chunk_memstats_collector.py new file mode 100644 index 0000000000000000000000000000000000000000..44c11302e89758c794bfa8371545b577a8b2545a --- /dev/null +++ b/colossalai/gemini/memory_tracer/chunk_memstats_collector.py @@ -0,0 +1,36 @@ +from typing import Optional + +from colossalai.gemini.chunk import ChunkManager +from colossalai.gemini.memory_tracer import MemStats +from colossalai.utils import get_current_device +from colossalai.utils.memory import colo_device_memory_capacity + +from .memstats_collector import MemStatsCollector + + +class ChunkMemStatsCollector(MemStatsCollector): + + def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None: + """ + + Memory Statistic Collector for Chunks. + + Args: + chunk_manager (ChunkManager): the chunk manager. + memstats (Optional[MemStats], optional): memory statistics collected by RMT. Defaults to None. + """ + super().__init__(memstats) + self._chunk_manager = chunk_manager + + # override + def record_model_data_volume(self) -> None: + """ + record model data volumn on cuda and cpu. + """ + if self._start_flag and not self.use_outside_memstats: + cuda_mem = self._chunk_manager.total_mem['cuda'] + self._memstats.record_max_cuda_model_data(cuda_mem) + + @property + def cuda_margin_mem(self) -> float: + return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda('cuda') diff --git a/colossalai/gemini/memory_tracer/memory_monitor.py b/colossalai/gemini/memory_tracer/memory_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..f8d99dbce7a43a8089dd1ebddd9bce6979a17f40 --- /dev/null +++ b/colossalai/gemini/memory_tracer/memory_monitor.py @@ -0,0 +1,147 @@ +import json +from abc import abstractmethod +from concurrent.futures import ThreadPoolExecutor +from time import sleep, time + +import torch + +from colossalai.utils import colo_device_memory_used, get_current_device + + +class MemoryMonitor: + """Base class for all types of memory monitor. + All monitors should have a list called `time_stamps` and a list called `mem_stats`. + """ + + def __init__(self): + self.time_stamps = [] + self.mem_stats = [] + + def __len__(self): + return len(self.mem_stats) + + @abstractmethod + def start(self): + pass + + @abstractmethod + def finish(self): + pass + + def state_dict(self): + return { + "time_stamps": self.time_stamps, + "mem_stats": self.mem_stats, + } + + def save(self, filename): + with open(filename, "w") as f: + json.dump(self.state_dict(), f) + + def clear(self): + self.mem_stats.clear() + self.time_stamps.clear() + + +class AsyncMemoryMonitor(MemoryMonitor): + """ + An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU + at interval of `1/(10**power)` sec. + + The idea comes from Runtime Memory Tracer of PatrickStar + `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_ + + Usage:: + + async_mem_monitor = AsyncMemoryMonitor() + input = torch.randn(2, 20).cuda() + OP1 = torch.nn.Linear(20, 30).cuda() + OP2 = torch.nn.Linear(30, 40).cuda() + + async_mem_monitor.start() + output = OP1(input) + async_mem_monitor.finish() + async_mem_monitor.start() + output = OP2(output) + async_mem_monitor.finish() + async_mem_monitor.save('log.pkl') + + Args: + power (int, optional): the power of time interva. Defaults to 10. + + .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management: + https://arxiv.org/abs/2108.05818 + """ + + def __init__(self, power: int = 10): + super().__init__() + self.keep_measuring = False + + current_device = get_current_device() + + def _set_cuda_device(): + torch.cuda.set_device(current_device) + + self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device) + self.monitor_thread = None + self.interval = 1 / (10**power) + + def set_interval(self, power: int): + self.clear() + self.interval = 1 / (10**power) + + def is_measuring(self): + return self.keep_measuring + + def start(self): + self.keep_measuring = True + self.monitor_thread = self.executor.submit(self._measure_usage) + + def finish(self): + if self.keep_measuring is False: + return 0 + + self.keep_measuring = False + max_usage = self.monitor_thread.result() + + self.monitor_thread = None + self.time_stamps.append(time()) + self.mem_stats.append(max_usage) + return max_usage + + def _measure_usage(self): + max_usage = 0 + while self.keep_measuring: + max_usage = max( + max_usage, + colo_device_memory_used(get_current_device()), + ) + sleep(self.interval) + return max_usage + + +class SyncCudaMemoryMonitor(MemoryMonitor): + """ + A synchronized cuda memory monitor. + It only record the maximum allocated cuda memory from start point to finish point. + """ + + def __init__(self, power: int = 10): + super().__init__() + + def start(self): + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + def finish(self) -> int: + """ + return max gpu memory used since latest `start()`. + + Returns: + int: max GPU memory + """ + torch.cuda.synchronize() + self.time_stamps.append(time()) + max_usage = torch.cuda.max_memory_allocated() + self.mem_stats.append(max_usage) + return max_usage diff --git a/colossalai/gemini/memory_tracer/memory_stats.py b/colossalai/gemini/memory_tracer/memory_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..0f8390e025fc43370aca82dbe8ab9121afa187a1 --- /dev/null +++ b/colossalai/gemini/memory_tracer/memory_stats.py @@ -0,0 +1,135 @@ +from typing import Any, Dict, List, Optional + +import torch + +from colossalai.gemini.memory_tracer import OrderedParamGenerator + + +class MemStats(object): + + def __init__(self) -> None: + """ + Store the non model data statistics used for Gemini and ZeroOptimizer. + """ + # (preop_step, List[param]) + self._step_param_dict = dict() + # (param, List[preop_step]) + self._param_step_dict = dict() + # (preop_step, non_model_data) non model data used during preop_step ~ (preop_step+1) + self._step_nmd_dict = dict() + self._param_runtime_order = OrderedParamGenerator() + + self._preop_step = 0 + + self._prev_overall_cuda = -1 + self._max_overall_cuda = 0 + self._prev_md_cuda = -1 + + # old version + self._model_data_cuda_list = [] + self._model_data_cpu_list = [] + + self._overall_cuda_list = [] + self._overall_cpu_list = [] + + self._non_model_data_cuda_list = [] + self._non_model_data_cpu_list = [] + + def calc_max_cuda_non_model_data(self): + if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1: + max_cuda_non_model_data = self._prev_overall_cuda - self._prev_md_cuda + self._step_nmd_dict[self._preop_step - 1] = max_cuda_non_model_data + # compatibility of the old version. + self._non_model_data_cuda_list.append(max_cuda_non_model_data) + + def record_max_cuda_model_data(self, val): + self._prev_md_cuda = val + + def record_max_cuda_overall_data(self, val): + self._prev_overall_cuda = val + self._max_overall_cuda = max(self._max_overall_cuda, val) + + @property + def max_overall_cuda(self): + return self._max_overall_cuda + + def increase_preop_step(self, param_list: List[torch.nn.Parameter]): + """ + the time step is increased. param list is used between current and the next + time step. + + Args: + param_list (List[torch.nn.Parameter]): a list of torch paramters. + """ + for p in param_list: + if p not in self._param_step_dict: + self._param_step_dict[p] = [self._preop_step] + else: + self._param_step_dict[p].append(self._preop_step) + self._param_runtime_order.append(p) + self._step_param_dict[self._preop_step] = param_list + self._preop_step += 1 + + def param_used_step(self, param: torch.nn.Parameter) -> Optional[List[int]]: + """param_used_step + get the timestep list using the param + + Args: + param (torch.nn.Parameter): a torch param + + Returns: + Optional[List[int]]: a list of int indicates the time step of preop hook. + """ + if param not in self._param_step_dict: + return None + else: + return self._param_step_dict[param] + + def param_order(self): + if self._param_runtime_order.is_empty(): + raise RuntimeError + else: + return self._param_runtime_order + + def non_model_data_list(self, device_type: str) -> List[int]: + if device_type == 'cuda': + return self._non_model_data_cuda_list + elif device_type == 'cpu': + return self._non_model_data_cpu_list + else: + raise TypeError + + def max_non_model_data(self, device_type: str) -> float: + if device_type == 'cuda': + return max(self._non_model_data_cuda_list) + elif device_type == 'cpu': + return max(self._non_model_data_cpu_list) + else: + raise TypeError + + def max_overall_cuda(self, device_type: str) -> float: + if device_type == 'cuda': + return max(self._overall_cuda_list) + elif device_type == 'cpu': + return max(self._overall_cpu_list) + else: + raise TypeError + + def clear(self): + self._model_data_cuda_list = [] + self._overall_cuda_list = [] + + self._model_data_cpu_list = [] + self._overall_cpu_list = [] + + self._non_model_data_cpu_list = [] + self._non_model_data_cuda_list = [] + + self._param_runtime_order.clear() + self._step_param_dict.clear() + self._param_step_dict.clear() + self._step_nmd_dict.clear() + self._preop_step = 0 + + self._prev_overall_cuda = -1 + self._prev_md_cuda = -1 diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/gemini/memory_tracer/memstats_collector.py new file mode 100644 index 0000000000000000000000000000000000000000..d521fe21231ca0a5fc94b4a90307cb133db6aa60 --- /dev/null +++ b/colossalai/gemini/memory_tracer/memstats_collector.py @@ -0,0 +1,104 @@ +import time +from typing import List, Optional + +import torch + +from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor +from colossalai.gemini.stateful_tensor import StatefulTensor +from colossalai.utils.memory import colo_device_memory_used + +from .memory_stats import MemStats + + +class MemStatsCollector: + """ + A Memory statistic collector. + It works in two phases. + Phase 1. Collection Phase: collect memory usage statistics of CPU and GPU. + The first iteration of DNN training. + Phase 2. Runtime Phase: use the read-only collected stats + The rest iterations of DNN training. + + It has a Sampling counter which is reset after DNN training iteration. + """ + + def __init__(self, memstats: Optional[MemStats] = None) -> None: + self._mem_monitor = SyncCudaMemoryMonitor() + self._sampling_time = [] + + self._start_flag = False + self._step_idx = 0 + self._step_total = 0 + if memstats is not None: + self.use_outside_memstats = True + self._memstats = memstats + else: + self.use_outside_memstats = False + self._memstats = MemStats() + + def next_period_non_model_data_usage(self, device_type: str) -> int: + """Maximum non model data memory usage during the next Op run + + Args: + device_type (str): device type, can be 'cpu' or 'cuda'. + + Returns: + int: max non model data memory usage of current sampling period + """ + assert not self._start_flag, 'Cannot get mem stats info during collection phase.' + assert self._step_total > 0, 'Cannot get mem stats info before collection phase.' + assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \ + f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\ + f"step total {self._step_total}" + next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx] + self._step_idx = (self._step_idx + 1) % self._step_total + return next_non_model_data + + @property + def sampling_time(self): + return [t - self._sampling_time[0] for t in self._sampling_time] + + def start_collection(self): + print('start collection') + self._start_flag = True + self._mem_monitor.start() + + def finish_collection(self): + self.sample_overall_data() + # self._step_total = len(self._sampling_time) + self._step_total = len(self._memstats.non_model_data_list('cuda')) + self._start_flag = False + self._mem_monitor.finish() + print(f'finish_collection {self._step_total}') + + # deprecated + def record_model_data_volume(self) -> None: + """ + Sampling model data statistics. + """ + if self._start_flag and not self.use_outside_memstats: + # The following code work for ZeroInitContext, which is deprecated in v0.1.12 + cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] + cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu'] + self._memstats.append_model_data('cuda', cuda_mem) + self._memstats.append_model_data('cpu', cpu_mem) + + def sample_overall_data(self) -> None: + """ + Sampling overall and non model data cuda memory statistics. + """ + if self._start_flag and not self.use_outside_memstats: + cuda_overall = self._mem_monitor.finish() + self._memstats.record_max_cuda_overall_data(cuda_overall) + self._memstats.calc_max_cuda_non_model_data() + + self._mem_monitor.start() + + if self._start_flag: + self._sampling_time.append(time.time()) + + def clear(self) -> None: + self._memstats.clear() + self._start_flag = False + self._step_idx = 0 + self._step_total = 0 diff --git a/colossalai/gemini/memory_tracer/param_runtime_order.py b/colossalai/gemini/memory_tracer/param_runtime_order.py new file mode 100644 index 0000000000000000000000000000000000000000..638c0533ce926b6629906d8b113161345017295d --- /dev/null +++ b/colossalai/gemini/memory_tracer/param_runtime_order.py @@ -0,0 +1,42 @@ +from abc import ABC + +import torch + + +class ParamGenerator(ABC): + + def append(self, param: torch.nn.Parameter): + pass + + def generate(self): + pass + + def clear(self): + pass + + +class OrderedParamGenerator(ParamGenerator): + """OrderedParamGenerator + + Contain the order of parameters visited during runtime. + """ + + def __init__(self) -> None: + self.param_visited_order = [] + + def append(self, param: torch.nn.Parameter): + self.param_visited_order.append(param) + + def generate(self): + visited_set = set() + for p in self.param_visited_order: + if p not in visited_set: + yield p + visited_set.add(p) + del visited_set + + def is_empty(self): + return len(self.param_visited_order) == 0 + + def clear(self): + self.param_visited_order = [] diff --git a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..a643751da7e2df39bf7e9195707971a31d9cba0f --- /dev/null +++ b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py @@ -0,0 +1,99 @@ +import torch.nn + +from colossalai.gemini.memory_tracer import MemStats +from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook +from colossalai.nn.parallel.data_parallel import _cast_float +from colossalai.tensor.param_op_hook import ColoParamOpHookManager + +__all__ = ['RuntimeMemTracer'] + + +class RuntimeMemTracer(): + """RuntimeMemTracer for the module training using ColoParameter. + + Trace non-model memory usage during fwd+bwd process. + It is obtained by using a tensor with the same shape as the training process as the inputs + and running an single fwd+bwd to trace the statistics. + + NOTE() + 1. The premise to use this tracer is that the target DNN execute the same operations at each iterations, + 2. Module buffers are viewed as non-model data. + """ + + def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half): + super().__init__() + self.module = module + self.dtype = dtype + self._gradstat = GradMemStats() + self._memstats = MemStats() + self.param_op_hook = ParamMemTracerHook(self._memstats, self._gradstat) + self.grad_hook = GradMemTracerHook(self._gradstat) + self.cpu_param_data_dict = {} + + for p in module.parameters(): + p.data = p.data.to(dtype) + + self._cast_buffers_to_cuda_dtype() + + def parameters_in_runtime_order(self): + return self._memstats._param_runtime_order.generate() + + def memstats(self): + return self._memstats + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def _backup_params(self): + """ + The function is called before forward. Backup model params on cpu. + """ + for p in self.module.parameters(): + self.cpu_param_data_dict[p] = torch.empty(p.data.shape, dtype=self.dtype, device="cpu") + self.cpu_param_data_dict[p].copy_(p.data) + + def _restore_params(self): + """ + This function is called after backward. Restore model params. + """ + for p in self.module.parameters(): + p.data = torch.empty(p.data.shape, dtype=self.dtype, device="cpu", requires_grad=p.data.requires_grad) + p.data.copy_(self.cpu_param_data_dict[p]) + self.cpu_param_data_dict.clear() + + def _pre_forward(self): + self._clear_cuda_mem_info() + self._backup_params() + self.grad_hook.register_grad_hook(self.module) + self.param_op_hook.mem_monitor.start() + + def forward(self, *args, **kwargs): + args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype) + self.module.zero_grad(set_to_none=True) + self._pre_forward() + with ColoParamOpHookManager.use_hooks(self.param_op_hook): + outputs = self.module(*args, **kwargs) + return outputs + + def backward(self, loss): + with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): + loss.backward() + self._post_backward() + + def _post_backward(self): + cuda_volume = self.param_op_hook.mem_monitor.finish() + self._memstats.record_max_cuda_overall_data(cuda_volume) + # calc the last Op non model data + self._memstats.calc_max_cuda_non_model_data() + self.grad_hook.remove_grad_hook() + self._restore_params() + + def _clear_cuda_mem_info(self): + self._memstats.clear() + self._gradstat.clear() + + def _cast_buffers_to_cuda_dtype(self): + for buffer in self.module.buffers(): + buffer.data = buffer.cuda() + if torch.is_floating_point(buffer): + buffer.data = buffer.data.to(self.dtype) diff --git a/colossalai/gemini/memory_tracer/static_memstats_collector.py b/colossalai/gemini/memory_tracer/static_memstats_collector.py new file mode 100644 index 0000000000000000000000000000000000000000..3209881e100cee15011ea816323cc0141eccd0a8 --- /dev/null +++ b/colossalai/gemini/memory_tracer/static_memstats_collector.py @@ -0,0 +1,105 @@ +from typing import Optional + +import torch +import torch.nn as nn +from torch.fx import symbolic_trace + +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta +from colossalai.gemini.chunk import ChunkManager + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + +from .chunk_memstats_collector import ChunkMemStatsCollector + + +class ModuleInfos: + + def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str, + parent_module: torch.nn.Module): + self.module = module + self.module_name = module_name + self.module_full_name = module_full_name + self.parent_module = parent_module + + +class StaticMemStatsCollector(ChunkMemStatsCollector): + """ + A Static Memory statistic collector. + """ + + def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None: + super().__init__(chunk_manager) + self.module = module + self.module_info_list = [] + + def init_mem_stats(self, *inputs): + + self.register_opnodes_recursively(self.module) + self.refactor_module() + + self.module = self.module.cpu() + self.module.train() + + data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs] + gm = symbolic_trace(self.module) + interp = MetaInfoProp(gm) + interp.propagate(*data) + + total_mem = 0 + for inp in inputs: + total_mem += inp.numel() * inp.element_size() + last_node = None + module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list] + for node in gm.graph.nodes: + total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node) + if node.op == "call_module": + if node.name.endswith("_0") and node.name[:-2] in module_name_list: + self._non_model_data_cuda_list.append(total_mem) + last_node = node + self._non_model_data_cuda_list.append(total_mem) + self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:] + + cur_module_mem_fwd = 0 + cur_module_mem_bwd = 0 + grad_module_out = last_node.meta["fwd_mem_out"] + for node in gm.graph.nodes.__reversed__(): + cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node) + cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"] + if node.op == "call_module": + if node.name.endswith("_0") and node.name[:-2] in module_name_list: + self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd) + total_mem = total_mem - cur_module_mem_fwd + cur_module_mem_fwd = 0 + cur_module_mem_bwd = 0 + grad_module_out = node.meta["bwd_mem_out"] + + self._step_total = len(self._non_model_data_cuda_list) + self.recover_module() + + def refactor_module(self): + for modInfo in self.module_info_list: + temp_node = nn.Sequential(nn.ReLU(), modInfo.module) + modInfo.parent_module.__setattr__(modInfo.module_name, temp_node) + + def recover_module(self): + for modInfo in self.module_info_list: + modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module) + + def register_opnodes_recursively(self, + module: torch.nn.Module, + name: str = "", + full_name: str = "", + parent_module: Optional[torch.nn.Module] = None): + + assert isinstance(module, torch.nn.Module) + + for child_name, child in module.named_children(): + self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module) + + # Early return on modules with no parameters. + if len(list(module.parameters(recurse=False))) == 0: + return + + self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module)) diff --git a/colossalai/gemini/memory_tracer/utils.py b/colossalai/gemini/memory_tracer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6962c058110e245a6bbd2470d75c137b54202aae --- /dev/null +++ b/colossalai/gemini/memory_tracer/utils.py @@ -0,0 +1,59 @@ +from typing import Optional, Tuple + +import torch + + +def colo_model_optimizer_usage(optim) -> Tuple[int, int]: + """Trace the optimizer memory usage + + Args: + optim (ShardedOptimV2): an instance of ShardedOptimver + + Returns: + Tuple[int, int]: cuda/cpu memory usage in Byte + """ + if optim is None: + return 0, 0 + assert hasattr(optim, 'get_memory_usage'), f"{type(optim)} has no attr get_memory_usage()" + return optim.get_memory_usage() + + +def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]: + """ + Trace the model memory usage. + Args: + model (torch.nn.Module): a torch model + + Returns: + Tuple[int, int]: cuda memory usage in Byte, cpu memory usage in Byte + """ + if model is None: + return 0, 0 + + def _get_tensor_mem_use(t: Optional[torch.Tensor]): + if t is None: + return 0, 0 + assert isinstance(t, torch.Tensor) + _cpu_mem_usage, _cuda_mem_usage = 0, 0 + if t.device.type == 'cpu': + _cpu_mem_usage += t.numel() * t.element_size() + elif t.device.type == 'cuda': + _cuda_mem_usage += t.numel() * t.element_size() + return _cuda_mem_usage, _cpu_mem_usage + + cuda_mem_usage = 0 + cpu_mem_usage = 0 + for param in model.parameters(): + if hasattr(param, 'colo_attr'): + t_cuda, t_cpu = param.colo_attr.get_memory_usage() + cuda_mem_usage += t_cuda + cpu_mem_usage += t_cpu + else: + t_cuda, t_cpu = _get_tensor_mem_use(param.data) + cuda_mem_usage += t_cuda + cpu_mem_usage += t_cpu + t_cuda, t_cpu = _get_tensor_mem_use(param.grad) + cuda_mem_usage += t_cuda + cpu_mem_usage += t_cpu + + return cuda_mem_usage, cpu_mem_usage diff --git a/colossalai/gemini/ophooks/__init__.py b/colossalai/gemini/ophooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b65726166644f05b1972b88a7c0358fb7eeb737f --- /dev/null +++ b/colossalai/gemini/ophooks/__init__.py @@ -0,0 +1,3 @@ +from .utils import BaseOpHook, register_ophooks_recursively + +__all__ = ["BaseOpHook", "register_ophooks_recursively"] diff --git a/colossalai/gemini/ophooks/_shard_grad_ophook.py b/colossalai/gemini/ophooks/_shard_grad_ophook.py new file mode 100644 index 0000000000000000000000000000000000000000..5115ff74da16b224c86ddecc29d1e4e470d01046 --- /dev/null +++ b/colossalai/gemini/ophooks/_shard_grad_ophook.py @@ -0,0 +1,32 @@ +import torch + +from colossalai.registry import OPHOOKS + +from . import BaseOpHook + + +@OPHOOKS.register_module +class ShardGradMemTracerHook(BaseOpHook): + """ + A hook to process sharded param before and afther FWD and BWD operator executing. + """ + + def __init__(self): + super().__init__() + + def pre_fwd_exec(self, module: torch.nn.Module, *args): + pass + + def post_fwd_exec(self, module: torch.nn.Module, *args): + pass + + def pre_bwd_exec(self, module: torch.nn.Module, input, output): + for param in module.parameters(): + assert hasattr(param, '_sharded_grad') + param._sharded_grad.setup() + + def post_bwd_exec(self, module: torch.nn.Module, input): + pass + + def post_iter(self): + pass diff --git a/colossalai/gemini/ophooks/_shard_param_ophook.py b/colossalai/gemini/ophooks/_shard_param_ophook.py new file mode 100644 index 0000000000000000000000000000000000000000..57f76970cc8631955c06ee244dd8d939240b9a88 --- /dev/null +++ b/colossalai/gemini/ophooks/_shard_param_ophook.py @@ -0,0 +1,47 @@ +import torch +from colossalai.registry import OPHOOKS + +from . import BaseOpHook + + +@OPHOOKS.register_module +class ShardParamHook(BaseOpHook): + """ + A hook to process sharded param before and afther FWD and BWD operator executing. + """ + + def __init__(self): + super().__init__() + + def niter(self): + return self._niter + + def pre_fwd_exec(self, module: torch.nn.Module, *args): + for param in module.parameters(): + assert hasattr(param, 'ca_attr') + param.ca_attr.gather() + param.data = param.ca_attr.payload() + + def post_fwd_exec(self, module: torch.nn.Module, *args): + for param in module.parameters(): + assert hasattr(param, 'ca_attr') + param.ca_attr.shard() + param.data = param.ca_attr.payload() + + def pre_bwd_exec(self, module: torch.nn.Module, input, output): + for param in module.parameters(): + assert hasattr(param, 'ca_attr') + param.ca_attr.gather() + param.data = param.ca_attr.payload() + + def post_bwd_exec(self, module: torch.nn.Module, input): + for param in module.parameters(): + assert hasattr(param, 'ca_attr') + param.ca_attr.shard() + param.data = param.ca_attr.payload() + + def pre_iter(self): + pass + + def post_iter(self): + pass diff --git a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..6d0df4e615ca03c19ad18344f29d6858ade8bf65 --- /dev/null +++ b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py @@ -0,0 +1,145 @@ +from contextlib import contextmanager +from enum import Enum +from functools import partial +from typing import List + +import torch + +from colossalai.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor +from colossalai.gemini.tensor_utils import alloc_storage, free_storage +from colossalai.tensor.param_op_hook import ColoParamOpHook + + +class TrainingPhase(Enum): + FORWARD = 0 + BACKWARD = 1 + + +class GradMemStats(): + + def __init__(self) -> None: + self.unreleased_grad_flag = {} + self.unreleased_grad_volume = 0 + + def clear(self): + self.unreleased_grad_flag.clear() + self.unreleased_grad_volume = 0 + + +class GradMemTracerHook(): + + def __init__(self, grad_stats: GradMemStats): + self.grad_hook_list = [] + self._grad_stats = grad_stats + + def grad_handle(self, p, grad): + assert self._grad_stats.unreleased_grad_flag[p] + free_storage(grad) + self._grad_stats.unreleased_grad_volume -= grad.numel() * grad.element_size() + self._grad_stats.unreleased_grad_flag[p] = False + + def register_grad_hook(self, module: torch.nn.Module): + for p in module.parameters(): + if p.requires_grad: + self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p))) + self._grad_stats.unreleased_grad_flag[p] = False + + def remove_grad_hook(self): + for hook in self.grad_hook_list: + hook.remove() + + +class ParamMemTracerHook(ColoParamOpHook): + + def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None: + super().__init__() + self._training_phase = TrainingPhase.FORWARD + self._memstats = memstats + self._grad_stats = gradstats + self.mem_monitor = SyncCudaMemoryMonitor() + + def _free_cuda_params(self, params): + for p in params: + if p.data.device.type == "cpu": + raise NotImplementedError("Only free cuda memory") + free_storage(p.data) + + def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]): + """ + move params to cuda + + Args: + params (List[torch.nn.Parameter]): target params + + Raises: + NotImplementedError: raise error when param has cpu grad + """ + for p in params: + cur_dev = p.data.device.type + if cur_dev == "cpu": + if p.grad is not None and p.grad.device.type == "cpu": + raise NotImplementedError("Only run in forward propagation") + p.data = torch.empty(p.data.shape, + device="cuda", + dtype=p.data.dtype, + requires_grad=p.data.requires_grad) + elif cur_dev == "cuda": + alloc_storage(p.data) + + def record_model_data_volume(self, params): + """ + get cuda model data used by params + """ + data_volume = self._grad_stats.unreleased_grad_volume + for p in params: + cur_model_data_volume = p.data.numel() * p.data.element_size() + data_volume += cur_model_data_volume + if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad: + # add param.grad, actually param.grad is None in this time + data_volume += cur_model_data_volume + if not self._grad_stats.unreleased_grad_flag[p]: + self._grad_stats.unreleased_grad_volume += cur_model_data_volume + self._grad_stats.unreleased_grad_flag[p] = True + # record max non model data used for this Op + self._memstats.record_max_cuda_model_data(data_volume) + + def pre_op(self, params): + max_cuda_used_pre_op = self.mem_monitor.finish() + # record max cuda overall data for prev OP. + self._memstats.record_max_cuda_overall_data(max_cuda_used_pre_op) + # record max cuda non model data for prev OP. + self._memstats.calc_max_cuda_non_model_data() + + self._allocate_params_on_cuda(params) + # record max cuda model data for current OP + self.record_model_data_volume(params) + + self.mem_monitor.start() + self._memstats.increase_preop_step(params) + + def post_op(self, params): + self._free_cuda_params(params) + + def pre_forward(self, params: List[torch.Tensor]) -> None: + self.pre_op(params) + + def post_forward(self, params: List[torch.Tensor]) -> None: + self.post_op(params) + + def pre_backward(self, params: List[torch.Tensor]) -> None: + self.pre_op(params) + + def post_backward(self, params: List[torch.Tensor]) -> None: + self.post_op(params) + + @contextmanager + def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD): + old_training_phase = self._training_phase + try: + self._training_phase = training_phase + yield + finally: + self._training_phase = old_training_phase + + switch_to_backward = switch_training_phase + switch_to_forward = partial(switch_to_backward, training_phase=TrainingPhase.FORWARD) diff --git a/colossalai/gemini/ophooks/utils.py b/colossalai/gemini/ophooks/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fe08405c82bf2a97ae306b1dd871e1c0b73afc6f --- /dev/null +++ b/colossalai/gemini/ophooks/utils.py @@ -0,0 +1,142 @@ +import torch +from typing import List, Callable, Optional + +from abc import ABC, abstractmethod +import torch + + +class BaseOpHook(ABC): + """This class allows users to add customized operations + before and after the execution of a PyTorch submodule""" + + def __init__(self): + pass + + @abstractmethod + def pre_fwd_exec(self, module: torch.nn.Module, *args): + pass + + @abstractmethod + def post_fwd_exec(self, module: torch.nn.Module, *args): + pass + + @abstractmethod + def pre_bwd_exec(self, module: torch.nn.Module, input, output): + pass + + @abstractmethod + def post_bwd_exec(self, module: torch.nn.Module, input): + pass + + @abstractmethod + def post_iter(self): + pass + + +# apply torch.autograd.Function that calls a backward_function to tensors in output +def _apply_to_tensors_only(module, functional, backward_function, outputs): + if type(outputs) is tuple: + touched_outputs = [] + for output in outputs: + touched_output = _apply_to_tensors_only(module, functional, backward_function, output) + touched_outputs.append(touched_output) + return tuple(touched_outputs) + elif type(outputs) is torch.Tensor: + return functional.apply(module, backward_function, outputs) + else: + return outputs + + +class PreBackwardFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, module, pre_backward_function, outputs): + ctx.module = module + ctx.pre_backward_function = pre_backward_function + module.applied_pre_backward = False + outputs = outputs.detach() + return outputs + + @staticmethod + def backward(ctx, *args): + ctx.pre_backward_function(ctx.module) + return (None, None) + args + + +class PostBackwardFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, module, pre_backward_function, output): + ctx.module = module + output = output.detach() + ctx.pre_backward_function = pre_backward_function + return output + + @staticmethod + def backward(ctx, *args): + """ + Args: + activation_grad of the next layer. + Returns: + grad of the input activation. + """ + ctx.pre_backward_function(ctx.module) + return (None, None) + args + + +def register_ophooks_recursively(module: torch.nn.Module, + ophook_list: List[BaseOpHook], + name: str = "", + filter_fn: Optional[Callable] = None): + r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD.""" + assert isinstance(module, torch.nn.Module) + assert isinstance(ophook_list, (list, tuple)) + assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0' + for hook in ophook_list: + assert (isinstance(hook, BaseOpHook)) + + # Add hooks for submodules + for child_name, child in module.named_children(): + register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn) + + # Early return on modules with no parameters. + if len(list(module.parameters(recurse=False))) == 0: + return + + # return from flitered module + if filter_fn is not None and filter_fn(module): + return + + def _pre_forward_module_hook(submodule, *args): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.pre_fwd_exec(submodule, *args) + + def _post_forward_module_hook(submodule, *args): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.post_fwd_exec(submodule, *args) + + def _pre_backward_module_hook(submodule, inputs, output): + + def _run_before_backward_function(submodule): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.pre_bwd_exec(submodule, inputs, output) + + return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output) + + def _post_backward_module_hook(submodule, inputs): + + def _run_after_backward_function(submodule): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.post_bwd_exec(submodule, inputs) + + return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs) + + module.register_forward_pre_hook(_pre_forward_module_hook) + module.register_forward_hook(_post_forward_module_hook) + + module.register_forward_hook(_pre_backward_module_hook) + module.register_forward_pre_hook(_post_backward_module_hook) diff --git a/colossalai/gemini/paramhooks/__init__.py b/colossalai/gemini/paramhooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e423993784afb0b25e0b9d068e4968763919d03 --- /dev/null +++ b/colossalai/gemini/paramhooks/__init__.py @@ -0,0 +1,3 @@ +from ._param_hookmgr import BaseParamHookMgr + +__all__ = ["BaseParamHookMgr"] diff --git a/colossalai/gemini/paramhooks/_param_hookmgr.py b/colossalai/gemini/paramhooks/_param_hookmgr.py new file mode 100644 index 0000000000000000000000000000000000000000..ee57cb46a90d5604f5f9b4e0a80fc4701fcba23c --- /dev/null +++ b/colossalai/gemini/paramhooks/_param_hookmgr.py @@ -0,0 +1,38 @@ +from typing import Callable, List +import torch +import functools + + +class BaseParamHookMgr(object): + + def __init__(self, param_list: List[torch.nn.Parameter]) -> None: + r""" + register backward hook on every parameters of module + """ + self._param_list = param_list + self._hook_list = [] + + def register_backward_hooks(self, hook_call: Callable) -> None: + r""" + The hook_call will be called every time a gradient with respect to the a param in self.param_list + is computed. + The hook should have the following signature: + ``` + hook(param, grad) -> Tensor or None + ``` + """ + if not torch.is_grad_enabled(): + return # don't register grad hooks if grad isn't enabled + for p in self._param_list: + if p.requires_grad and not hasattr(p, '_base_param_hook'): + handle = p.register_hook(functools.partial(hook_call, p)) + p._base_param_hook = handle + + def remove_hooks(self) -> None: + """ + Remove hooks from model parameters. + """ + + for p in self._param_list: + if p.requires_grad and hasattr(p, '_base_param_hook'): + p._base_param_hook.remove() diff --git a/colossalai/gemini/placement_policy.py b/colossalai/gemini/placement_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..50004ec35f3aa6a76282c403c1c957e6ce58e342 --- /dev/null +++ b/colossalai/gemini/placement_policy.py @@ -0,0 +1,245 @@ +import functools +from abc import ABC, abstractmethod +from time import time +from typing import Dict, List, Optional, Tuple, Type + +import torch + +from colossalai.gemini.chunk import Chunk, ChunkManager +from colossalai.gemini.memory_tracer import ChunkMemStatsCollector +from colossalai.utils import get_current_device +from colossalai.utils.memory import colo_device_memory_capacity + + +class PlacementPolicy(ABC): + need_mem_stats: bool = False + + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: + self.chunk_manager = chunk_manager + self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector + + @abstractmethod + def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: + raise NotImplementedError + + @staticmethod + def get_default_device() -> torch.device: + return torch.device('cpu') + + +class CPUPlacementPolicy(PlacementPolicy): + + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: + super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) + + def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: + volume = 0 + start = time() + for chunk in can_evict_chunks: + self.chunk_manager.release_chunk(chunk) + self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + volume += chunk.chunk_mem + return volume, time() - start + + +class CUDAPlacementPolicy(PlacementPolicy): + + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: + assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available' + super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) + + def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: + return 0, 0 + + @staticmethod + def get_default_device() -> torch.device: + return get_current_device() + + +class AutoPlacementPolicy(PlacementPolicy): + + need_mem_stats: bool = True + # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase + # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() + # and AutoPlacementPolicy.set_steady_cuda_cap_ratio() + _warmup_non_model_data_ratio: float = 0.8 + _steady_cuda_cap_ratio: float = 0.9 + + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: + super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) + + def evict_tensors(self, + can_evict_chunks: List[Chunk], + cuda_demand: int = 0, + warmup: bool = True, + compute_list: Optional[List[Tuple[Chunk, ...]]] = None, + compute_idx: int = 0, + **kwargs) -> Tuple[int, float]: + """ + Evict tensors from CUDA device. + + Args: + can_evict_chunks (List[StatefulTensor]): the list of tensors that can be evicted. + cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0. + warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True. + compute_list (List[StatefulTensor], optional): TODO. Defaults to []. + compute_idx (int, optional): the idx of computing device. Defaults to 0. + + Raises: + RuntimeError: + + Returns: + int: the volume of memory that is evicted + """ + start = time() + cuda_capacity = colo_device_memory_capacity(get_current_device()) + used_cuda_model_data = self.chunk_manager.total_mem['cuda'] + if warmup: + # We designate a part of CUDA memory for model data in warmup iterations. + max_cuda_non_model_data_per_period = cuda_capacity * AutoPlacementPolicy._warmup_non_model_data_ratio + else: + # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment. + max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda') + cuda_capacity *= AutoPlacementPolicy._steady_cuda_cap_ratio + total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period + avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data + freed_cuda_model_data = 0 + + if avail_cuda_model_data < cuda_demand: + # Move cuda_demand - avail_cuda_model_data volume of tensors + # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data + to_free_cuda_model_data = cuda_demand - avail_cuda_model_data + to_free_chunks = can_evict_chunks + if not warmup: + to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list)) + # print(self._sort_can_evict_chunks.cache_info()) + for chunk in to_free_chunks: + if freed_cuda_model_data >= to_free_cuda_model_data: + break + + self.chunk_manager.release_chunk(chunk) + self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + freed_cuda_model_data += chunk.chunk_mem + if freed_cuda_model_data < to_free_cuda_model_data: + raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! " + f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}") + return freed_cuda_model_data, time() - start + + @staticmethod + @functools.lru_cache(maxsize=None) + def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list: + next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks} + for i in range(len(compute_list) - 1, compute_idx, -1): + for chunk in compute_list[i]: + if chunk in next_compute_idx: + next_compute_idx[chunk] = i + next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True) + return [t for (t, idx) in next_compute_idx] + + @staticmethod + def set_warmup_non_model_data_ratio(ratio: float) -> None: + ratio = float(ratio) + assert 0.0 < ratio < 1.0 + AutoPlacementPolicy._warmup_non_model_data_ratio = ratio + + @staticmethod + def set_steady_cuda_cap_ratio(ratio: float) -> None: + ratio = float(ratio) + assert 0.0 < ratio < 1.0 + AutoPlacementPolicy._steady_cuda_cap_ratio = ratio + + +class ConstPlacementPolicy(PlacementPolicy): + + need_mem_stats: bool = False + _accessed_memory_boundary = 512 * 1024**2 + + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: + super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) + + def evict_tensors(self, + can_evict_chunks: List[Chunk], + cuda_demand: int = 0, + warmup: bool = True, + compute_list: Optional[List[Tuple[Chunk, ...]]] = None, + compute_idx: int = 0, + **kwargs) -> Tuple[int, float]: + """ + See the docstrings in the class `AutoPlacementPolicy`. + """ + start = time() + used_accessed_memory = self.chunk_manager.accessed_mem + avail_accessed_memory = ConstPlacementPolicy._accessed_memory_boundary - used_accessed_memory + freed_accessed_memory = 0 + + if avail_accessed_memory < cuda_demand: + to_free_memory = cuda_demand - avail_accessed_memory + to_free_chunks = can_evict_chunks + + if not warmup: + # sort all chunks + to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list)) + + for chunk in to_free_chunks: + if freed_accessed_memory >= to_free_memory: + break + + self.chunk_manager.release_chunk(chunk) + self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + freed_accessed_memory += chunk.chunk_mem + + if freed_accessed_memory < to_free_memory: + raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! " + f"Need {to_free_memory}, freed {freed_accessed_memory}") + return freed_accessed_memory, time() - start + + @staticmethod + @functools.lru_cache(maxsize=None) + def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list: + next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks} + for i in range(len(compute_list) - 1, compute_idx, -1): + for chunk in compute_list[i]: + if chunk in next_compute_idx: + next_compute_idx[chunk] = i + next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True) + return [t for (t, idx) in next_compute_idx] + + @staticmethod + def set_const_memory_boundary(cuda_memory_mb: int) -> None: + boundary = int(cuda_memory_mb * 1024**2) + assert boundary > 0 + ConstPlacementPolicy._accessed_memory_boundary = boundary + + +class PlacementPolicyFactory: + policies: Dict[str, Type[PlacementPolicy]] = { + 'cpu': CPUPlacementPolicy, + 'cuda': CUDAPlacementPolicy, + 'auto': AutoPlacementPolicy, + 'const': ConstPlacementPolicy + } + + @staticmethod + def create(policy_name: str) -> Type[PlacementPolicy]: + if policy_name not in PlacementPolicyFactory.policies: + raise TypeError(f"Unknown tensor placement policy {policy_name}") + return PlacementPolicyFactory.policies[policy_name] + + @staticmethod + def get_polocy_names(): + return tuple(PlacementPolicyFactory.policies.keys()) + + @staticmethod + def get_default_device(policy_name: str) -> torch.device: + policy_cls = PlacementPolicyFactory.create(policy_name) + return policy_cls.get_default_device() diff --git a/colossalai/gemini/stateful_tensor.py b/colossalai/gemini/stateful_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..18fc8fd14d3c8ad173a60e20cf3d32a3993ee744 --- /dev/null +++ b/colossalai/gemini/stateful_tensor.py @@ -0,0 +1,209 @@ +from enum import Enum +from typing import Optional +import torch +from typing import Union + +from colossalai.gemini.gemini_context import GeminiMemoryManager + + +def sizeof_tensor(tensor: torch.Tensor): + return tensor.numel() * tensor.element_size() + + +class TensorState(Enum): + FREE = 0 + HOLD = 1 + HOLD_AFTER_FWD = 2 + HOLD_AFTER_BWD = 3 + COMPUTE = 4 + + +class StatefulTensor(object): + """A Structure stores a Torch Tensor and labeled states. + Inspired from the paper: + PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management + + https://arxiv.org/abs/2108.05818 + """ + # Global Stateful Tensor Manager + GST_MGR = GeminiMemoryManager(TensorState) + + def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None: + self._state = state + self._payload = None + self._payload_size = 0 # byte size of current payload + + StatefulTensor.GST_MGR.register_new_instance() + + if self._state == TensorState.FREE: + # when the state is free, payload should be None + assert maybe_tensor is None, f"payload has to None if state is {self._state}" + else: + # otherwise, payload should not be None + assert maybe_tensor is not None, f"payload can't be None if state is {self._state}" + self._payload = maybe_tensor + self._payload_size = sizeof_tensor(maybe_tensor) + self.__trans_state_update(TensorState.FREE, state) + + def data_ptr(self): + if self._payload is None: + return 0 # if a tensor has no storage, 0 should be returned + return self._payload.data_ptr() + + def set_null(self) -> None: + # notice that free stateful tensor do not need to become null again + if self.state != TensorState.FREE: + self.__trans_state_update(self.state, TensorState.FREE) + self.__release() + + def is_null(self) -> bool: + if self.state == TensorState.FREE: + # check sanity here + assert self.payload is None + return True + return False + + def trans_state(self, state: TensorState) -> None: + if self.state == TensorState.FREE: + # free stateful tensor can't change state + assert state == TensorState.FREE, "Free stateful tensor can't change to other states" + return + + self.__trans_state_update(self.state, state) + + if state == TensorState.FREE: + self.__release() + else: + self._state = state + + def move_to(self, device: Union[torch.device, int]): + assert self.state is not TensorState.FREE, "Can't move free stateful tensor" + + if not isinstance(device, torch.device): + to_device = torch.device('cuda', device) + else: + to_device = device + + from_device_type = self.device.type + if from_device_type == to_device.type: + # from device == to device + return + + # update manager's information + self.__trans_device_update(from_device_type, to_device.type) + self.payload.data = self.payload.data.to(to_device) + + def payload_copy(self, tensor) -> None: + self._payload.view(-1).copy_(tensor.view(-1)) + + def payload_reset(self, tensor) -> None: + + assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead" + + if self.payload is not None: + # release old payload + self.__trans_state_update(self.state, TensorState.FREE) + else: + # otherwise, set the state to HOLD for new payload + self._state = TensorState.HOLD + del self._payload + + self._payload = tensor + self._payload_size = sizeof_tensor(tensor) + # record new payload + self.__trans_state_update(TensorState.FREE, self.state) + + def payload_relay(self, rhs): + # relay the payload of rhs to current stateful tensor + # can't support null relay right now + assert not rhs.is_null() + + # now this function only support stateful tensor that has zero-length payload + # because it doesn't require memory manager updating + # you can extend this function by yourself + assert self.payload_size == 0 + + self._payload = rhs.payload + self._payload_size = rhs.payload_size + self._state = TensorState.HOLD + self.__trans_state_update(rhs.state, TensorState.HOLD) + + rhs.__release() + + @property + def payload(self) -> Optional[torch.Tensor]: + return self._payload + + @property + def payload_size(self) -> int: + return self._payload_size + + @property + def state(self) -> TensorState: + return self._state + + @property + def device(self) -> torch.device: + return self._payload.device + + @property + def dtype(self) -> torch.dtype: + return self._payload.dtype + + @property + def shape(self): + return self._payload.shape + + def to(self, device: torch.device): + raise RuntimeError("Use move_to(...) instead of call .to() on StatefulTensor") + + def to_(self, device: torch.device): + raise RuntimeError("Use move_to(...) instead of call .to_() on StatefulTensor") + + def __release(self): + # release current payload + # shouldn't be visible to users + self._state = TensorState.FREE + self._payload = None + self._payload_size = 0 + + def __trans_state_update(self, from_state: TensorState, to_state: TensorState): + """Update global manager when changing the state of a tensor + """ + manager = StatefulTensor.GST_MGR + size = self.payload_size + device_type = self.device.type + + if from_state != TensorState.FREE: + manager.state_mem[device_type][from_state] -= size + else: + # when from_state is FREE, the tensor is new to manager + # we should add its memory + manager.total_mem[device_type] += size + + if to_state != TensorState.FREE: + manager.state_mem[device_type][to_state] += size + else: + # when to_state is FREE, the tensor will be deleted soon + # we should sub its memory + manager.total_mem[device_type] -= size + + def __trans_device_update(self, from_type: str, to_type: str): + """Update global manager when changing the device of a tensor + """ + manager = StatefulTensor.GST_MGR + size = self.payload_size + state = self.state + + # update aggregated information + manager.total_mem[from_type] -= size + manager.total_mem[to_type] += size + + # update the information of each state + manager.state_mem[from_type][state] -= size + manager.state_mem[to_type][state] += size + + def __del__(self): + self.set_null() + StatefulTensor.GST_MGR.delete_instance() + del self diff --git a/colossalai/gemini/stateful_tensor_mgr.py b/colossalai/gemini/stateful_tensor_mgr.py new file mode 100644 index 0000000000000000000000000000000000000000..c300f9bffc8906ee42187f0beb32db91fc586d52 --- /dev/null +++ b/colossalai/gemini/stateful_tensor_mgr.py @@ -0,0 +1,100 @@ +import functools +import torch +import types +from colossalai.utils.cuda import get_current_device +from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState +from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy +from typing import List +from colossalai.logging import get_dist_logger +from time import time + + +class StatefulTensorMgr(object): + """ + Stateful Tensor Manager, inspired from PatrickStar + + PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management + https://arxiv.org/abs/2108.05818 + """ + + def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None: + self._tensor_placement_policy: TensorPlacementPolicy = tensor_placement_policy + self._stateful_tensor_list: List[StatefulTensor] = [] + + self._compute_list: List[StatefulTensor] = [] + self._compute_idx: int = -1 + + self._cpu_gpu_move_volume = 0 + self._layout_time = 0 + self._evict_time = 0 + self._warmup = True + + def register_stateful_tensor_list(self, tensor_list: List[StatefulTensor]) -> None: + assert self._stateful_tensor_list == [], "Can't register stateful tensors for manager twice" + self._stateful_tensor_list = tensor_list + for t in self._stateful_tensor_list: + assert isinstance(t, StatefulTensor) + t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t) + + def start_iter(self): + pass + + def finish_iter(self): + """This function must be called when each iteration finishes + """ + self._warmup = False + self._compute_idx = -1 + self._cpu_gpu_move_volume = 0 + self._layout_time = 0 + self._evict_time = 0 + + def adjust_layout(self) -> None: + """ Adjust the layout of statefuil tensor according to the information provided + by mem_stats_collector, which should belongs to a Sharded Model. + """ + # find stateful tensor in state COMPUTE + cuda_demand = StatefulTensor.GST_MGR.state_mem['cpu'][TensorState.COMPUTE] + start = time() + move_to_cuda_tensor_list, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup) + self._layout_time += time() - start + vol, evict_time = self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list, + cuda_demand=cuda_demand, + warmup=self._warmup, + compute_list=self._compute_list, + compute_idx=self._compute_idx) + self._cpu_gpu_move_volume += vol + self._evict_time += evict_time + # move COMPUTE tensors to CUDA + self._cpu_gpu_move_volume += cuda_demand + for t in move_to_cuda_tensor_list: + colo_model_data_tensor_move_inline(t, get_current_device()) + + @property + def cpu_gpu_move_volume(self): + return self._cpu_gpu_move_volume + + def _trans_state(self, trans_state_func, stateful_tensor, state): + trans_state_func(state) + if state == TensorState.COMPUTE: + self._compute_idx += 1 + if self._warmup: + self._compute_list.append(stateful_tensor) + + @functools.lru_cache(maxsize=None) + def _get_layout_info(self, compute_idx: int, warmup: bool): + move_to_cuda_tensor_list = [] + hold_cuda_tensor_list = [] + for tensor in self._stateful_tensor_list: + if tensor.state == TensorState.FREE: + continue + + if tensor.device.type == 'cuda': + if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]: + hold_cuda_tensor_list.append(tensor) + elif tensor.device.type == 'cpu': + if tensor.state == TensorState.COMPUTE: + move_to_cuda_tensor_list.append(tensor) + else: + raise RuntimeError + return move_to_cuda_tensor_list, hold_cuda_tensor_list diff --git a/colossalai/gemini/tensor_placement_policy.py b/colossalai/gemini/tensor_placement_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..cfcfb385667c018a66b91608ece59bf3baa11724 --- /dev/null +++ b/colossalai/gemini/tensor_placement_policy.py @@ -0,0 +1,138 @@ +from abc import ABC, abstractmethod +from time import time +from typing import List, Optional +import torch +from colossalai.utils import get_current_device +from colossalai.utils.memory import colo_device_memory_capacity + +from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from colossalai.gemini.stateful_tensor import StatefulTensor +from colossalai.gemini.memory_tracer import MemStatsCollector +from typing import Type +import functools + + +class TensorPlacementPolicy(ABC): + + def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None: + self.device: Optional[torch.device] = device + self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector + + @abstractmethod + def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None: + raise NotImplementedError + + +class CPUTensorPlacementPolicy(TensorPlacementPolicy): + + def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: + super().__init__(torch.device('cpu'), mem_stats_collector=mem_stats_collector) + + def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int: + volume = 0 + for t in hold_cuda_tensor_list: + colo_model_data_tensor_move_inline(t, self.device) + volume += t.payload.numel() * t.payload.element_size() + return volume, 0 + + +class CUDATensorPlacementPolicy(TensorPlacementPolicy): + + def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: + assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available' + super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector) + + def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int: + return 0, 0 + + +class AutoTensorPlacementPolicy(TensorPlacementPolicy): + + def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: + super().__init__(None, mem_stats_collector=mem_stats_collector) + # model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase + # TODO(ver217): make these args configurable + self._warmup_non_model_data_ratio: float = 0.8 + self._steady_cuda_cap_ratio: float = 0.9 + + def evict_tensors(self, + hold_cuda_tensor_list: List[StatefulTensor], + cuda_demand: int = 0, + warmup: bool = True, + compute_list: List[StatefulTensor] = [], + compute_idx: int = 0, + **kwargs) -> int: + """ + Evict tensors from CUDA device. + + Args: + hold_cuda_tensor_list (List[StatefulTensor]): the list of tensor in state of HOLD-like + cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0. + warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True. + compute_list (List[StatefulTensor], optional): TODO. Defaults to []. + compute_idx (int, optional): the idx of computing device. Defaults to 0. + + Raises: + RuntimeError: + + Returns: + int: the volume of memory that is evicted + """ + start = time() + cuda_capacity = colo_device_memory_capacity(get_current_device()) + used_cuda_model_data = StatefulTensor.GST_MGR.total_mem['cuda'] + if warmup: + # We designate a part of CUDA memory for model data in warmup iterations. + max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio + else: + # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment. + max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda') + cuda_capacity *= self._steady_cuda_cap_ratio + total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period + avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data + freed_cuda_model_data = 0 + end = time() + if avail_cuda_model_data < cuda_demand: + # Move cuda_demand - avail_cuda_model_data volume of tensors + # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data + to_free_cuda_model_data = cuda_demand - avail_cuda_model_data + to_free_tensor_list = hold_cuda_tensor_list + if not warmup: + to_free_tensor_list = self._sort_hold_cuda_tensors(tuple(hold_cuda_tensor_list), compute_idx, + tuple(compute_list)) + # print(self._sort_hold_cuda_tensors.cache_info()) + end = time() + for t in to_free_tensor_list: + if freed_cuda_model_data >= to_free_cuda_model_data: + break + freed_cuda_model_data += t.payload_size + colo_model_data_tensor_move_inline(t, torch.device('cpu')) + if freed_cuda_model_data < to_free_cuda_model_data: + raise RuntimeError( + f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}" + ) + return freed_cuda_model_data, end - start + + @staticmethod + @functools.lru_cache(maxsize=None) + def _sort_hold_cuda_tensors(hold_cuda_tensors: tuple, compute_idx: int, compute_list: tuple) -> list: + next_compute_idx = {t: len(compute_list) for t in hold_cuda_tensors} + for i in range(len(compute_list) - 1, compute_idx, -1): + if compute_list[i] in next_compute_idx: + next_compute_idx[compute_list[i]] = i + next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True) + return [t for (t, idx) in next_compute_idx] + + +class TensorPlacementPolicyFactory: + + @staticmethod + def create(policy_name: str) -> Type[TensorPlacementPolicy]: + if policy_name == 'cpu': + return CPUTensorPlacementPolicy + elif policy_name == 'cuda': + return CUDATensorPlacementPolicy + elif policy_name == 'auto': + return AutoTensorPlacementPolicy + else: + raise TypeError(f"Unknown tensor placement policy {policy_name}") diff --git a/colossalai/gemini/tensor_utils.py b/colossalai/gemini/tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bcc159f9954a32f6125100a5e036bcf1560554f4 --- /dev/null +++ b/colossalai/gemini/tensor_utils.py @@ -0,0 +1,118 @@ +import torch +from colossalai.gemini.stateful_tensor import StatefulTensor +from typing import Union, Tuple + + +def is_storage_empty(tensor: torch.Tensor) -> bool: + return tensor.storage().size() == 0 + + +def free_storage(tensor: torch.Tensor) -> None: + if not is_storage_empty(tensor): + tensor.storage().resize_(0) + + +def alloc_storage(tensor: torch.Tensor) -> None: + if is_storage_empty(tensor): + tensor.storage().resize_(tensor.numel()) + + +def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]: + if isinstance(tensor, StatefulTensor): + t = tensor.payload + elif isinstance(tensor, torch.Tensor): + t = tensor + else: + return 0, 0 + + cuda_use, cpu_use = 0, 0 + + mem_use = t.storage().size() * t.element_size() + if t.device.type == 'cuda': + cuda_use += mem_use + elif t.device.type == 'cpu': + cpu_use += mem_use + + return cuda_use, cpu_use + + +def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor, + torch.Tensor]) -> None: + """ + A colossal API for model data tensor move. + The src and target tensors could be resident on both CPU and GPU. + + NOTE() The source tensor payload will be removed after this function. + + The function will record the communication volume between CPU and GPU. + Args: + src_t (Union[StatefulTensor, torch.Tensor]): source tensor + tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor + """ + if isinstance(src_t, StatefulTensor): + src_t_payload = src_t.payload + else: + src_t_payload = src_t.data + src_dev = src_t_payload.device + + if isinstance(tgt_t, StatefulTensor): + tgt_t_payload = tgt_t.payload + else: + tgt_t_payload = tgt_t.data + + tgt_t_payload.copy_(src_t_payload) + + # remove payload of src_t + if isinstance(src_t, StatefulTensor): + src_t.set_null() + else: + src_t.data = torch.empty(0, device=src_dev, dtype=src_t_payload.dtype) + + +def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device, + int]) -> None: + """ + move a tensor to the target_device + Args: + t (Union[StatefulTensor, torch.Tensor]): the tensor be moved + target_device: a traget device, if type is int, it the index of cuda card. + """ + if not isinstance(target_device, torch.device): + target_device = torch.device(f'cuda:{target_device}') + + if isinstance(t, torch.Tensor): + t.data = t.data.to(target_device) + elif isinstance(t, StatefulTensor): + t.move_to(target_device) + else: + raise TypeError(f'colo_model_data_tensor_move_inline dose not accept type {type(t)}') + + +def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None: + """colo_model_data_move_to_cpu + move a model data tensor from gpu to cpu + Args: + t (Union[StatefulTensor, torch.Tensor]): _description_ + """ + # TODO() optimize the tensor moving with non-blocking + if isinstance(t, torch.Tensor): + t.data = t.data.cpu() + elif isinstance(t, StatefulTensor): + t.move_to(torch.device('cpu')) + else: + raise TypeError(f'colo_model_data_move_to_cpu dose not accept type {type(t)}') + + +def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor: + """ + Clone a model data tensor + Args: + t (Union[StatefulTensor, torch.Tensor]): a model data tensor + target_device (torch.device): the target device + Returns: + torch.Tensor: a cloned torch tensor + """ + # TODO() rename this function + colo_model_data_tensor_move_inline(t, target_device) + t_payload = t.payload if isinstance(t, StatefulTensor) else t + return t_payload diff --git a/colossalai/global_variables.py b/colossalai/global_variables.py new file mode 100644 index 0000000000000000000000000000000000000000..61b31965e2e63d2119bfadba0d49478537c31fa7 --- /dev/null +++ b/colossalai/global_variables.py @@ -0,0 +1,56 @@ +from typing import Optional + + +class TensorParallelEnv(object): + _instance = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = object.__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self, *args, **kwargs): + self.load(*args, **kwargs) + + def load(self, + mode: Optional[str] = None, + vocab_parallel: bool = False, + parallel_input_1d: bool = False, + summa_dim: int = None, + tesseract_dim: int = None, + tesseract_dep: int = None, + depth_3d: int = None, + input_group_3d=None, + weight_group_3d=None, + output_group_3d=None, + input_x_weight_group_3d=None, + output_x_weight_group_3d=None): + self.mode = mode + self.vocab_parallel = vocab_parallel + self.parallel_input_1d = parallel_input_1d + self.summa_dim = summa_dim + self.tesseract_dim = tesseract_dim + self.tesseract_dep = tesseract_dep + self.depth_3d = depth_3d + self.input_group_3d = input_group_3d + self.weight_group_3d = weight_group_3d + self.output_group_3d = output_group_3d + self.input_x_weight_group_3d = input_x_weight_group_3d + self.output_x_weight_group_3d = output_x_weight_group_3d + + def save(self): + return dict(mode=self.mode, + vocab_parallel=self.vocab_parallel, + parallel_input_1d=self.parallel_input_1d, + summa_dim=self.summa_dim, + tesseract_dim=self.tesseract_dim, + tesseract_dep=self.tesseract_dep, + depth_3d=self.depth_3d, + input_group_3d=self.input_group_3d, + weight_group_3d=self.weight_group_3d, + output_group_3d=self.output_group_3d, + input_x_weight_group_3d=self.input_x_weight_group_3d, + output_x_weight_group_3d=self.output_x_weight_group_3d) + + +tensor_parallel_env = TensorParallelEnv() diff --git a/colossalai/initialize.py b/colossalai/initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..e907efddee693ed0c3b2bc100b5534fea52137c8 --- /dev/null +++ b/colossalai/initialize.py @@ -0,0 +1,472 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import argparse +import os +import pprint +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.modules.loss import _Loss +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + +from colossalai.core import global_context as gpc +from colossalai.context.moe_context import MOE_CONTEXT + +from colossalai.logging import get_dist_logger + +from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape +from colossalai.engine import Engine +from colossalai.gemini.ophooks import BaseOpHook + +from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param) +from colossalai.utils.moe import sync_moe_model_param + +from colossalai.amp import AMP_TYPE, convert_to_amp +from colossalai.amp.naive_amp import NaiveAMPModel +from colossalai.builder.builder import build_gradient_handler +from colossalai.context import Config, ConfigException, ParallelMode +from colossalai.engine.gradient_accumulation import accumulate_gradient + +from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer + +from colossalai.zero import convert_to_zero_v2 +from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2 + + +def get_default_parser(): + """Reads user command line and uses an argument parser to parse the input arguments. + Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed. + + Returns: + Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser. + """ + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, help='path to the config file') + parser.add_argument('--host', type=str, help='the master address for distributed training') + parser.add_argument('--port', type=int, help='the master port for distributed training') + parser.add_argument('--world_size', type=int, help='world size for distributed training') + parser.add_argument('--rank', type=int, help='rank for the default process group') + parser.add_argument('--local_rank', type=int, help='local rank on the node') + parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication') + return parser + + +def launch(config: Union[str, Path, Config, Dict], + rank: int, + world_size: int, + host: str, + port: int, + backend: str = 'nccl', + local_rank: int = None, + seed: int = 1024, + verbose: bool = True): + """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input + arguments are not given. Then initialize and set distributed environment by calling global_context's functions. + + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + rank (int): Rank for the default process group + world_size (int): World size of the default process group + host (str): The master address for distributed training + port (str): The master port for distributed training + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + local_rank (int, optional): + Rank for the process on the node and is used to set the default CUDA device, + defaults to None. If local_rank = None, the default device ordinal will be calculated automatically. + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. + + Raises: + Exception: Raise exception when config type is wrong + """ + gpc.verbose = verbose + + # set config + assert isinstance(config, (Config, str, Path, dict)), \ + f'expected argument config to be Config, str or Path, but got {type(config)}' + if not isinstance(config, Config) and isinstance(config, dict): + config = Config(config) + if isinstance(config, (str, Path)): + config = Config.from_file(config) + gpc.load_config(config) + + # init default process group + gpc.init_global_dist(rank, world_size, backend, host, port) + + # init process groups for different parallel modes from config + gpc.init_parallel_groups() + + # set cuda device + if torch.cuda.is_available(): + # if local rank is not given, calculate automatically + gpc.set_device(local_rank) + + # set the number of processes running on the same node + gpc.detect_num_processes_on_current_node() + + gpc.set_seed(seed) + + if verbose: + logger = get_dist_logger() + logger.info( + f'Distributed environment is initialized, ' + f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, ' + f'tensor parallel size: {gpc.tensor_parallel_size}', + ranks=[0]) + + +def launch_from_slurm(config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = 'nccl', + seed: int = 1024, + verbose: bool = True): + """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables + set by SLURM + + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + host (str): The master address for distributed training + port (str): The master port for distributed training + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. + """ + try: + rank = int(os.environ['SLURM_PROCID']) + world_size = int(os.environ['SLURM_NPROCS']) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM" + ) + + launch(config=config, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose) + + +def launch_from_openmpi(config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = 'nccl', + seed: int = 1024, + verbose: bool = True): + """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables + set by OpenMPI + + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + host (str): The master address for distributed training + port (str): The master port for distributed training + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. + """ + try: + rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI" + ) + + launch(config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose) + + +def launch_from_torch(config: Union[str, Path, Config, Dict], + backend: str = 'nccl', + seed: int = 1024, + verbose: bool = True): + """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size + from the environment variables set by PyTorch + + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. + """ + try: + rank = int(os.environ['RANK']) + local_rank = int(os.environ['LOCAL_RANK']) + world_size = int(os.environ['WORLD_SIZE']) + host = os.environ['MASTER_ADDR'] + port = int(os.environ['MASTER_PORT']) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" + ) + + launch(config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose) + + +def initialize(model: nn.Module, + optimizer: Optimizer, + criterion: Optional[_Loss] = None, + train_dataloader: Optional[Iterable] = None, + test_dataloader: Optional[Iterable] = None, + lr_scheduler: Optional[_LRScheduler] = None, + ophooks: Optional[List[BaseOpHook]] = None, + verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]: + """Core function to wrap the essential training components with our functionality based on the config which is + loaded into gpc.config. + + Args: + model (:class:`torch.nn.Module` or Callbale): Your model instance or a function to build the model. + optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`): + Your optimizer instance. + criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance. + train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training. + test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing. + lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional. + verbose (bool, optional): Whether to print logs. + + Returns: + Tuple (engine, train_dataloader, test_dataloader, lr_scheduler): + A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)`` + where only ``engine`` could not be None. + """ + # get logger + logger = get_dist_logger() + gpc.verbose = verbose + + # get config from gpc + config = gpc.config + + # print config + if verbose: + logger.info( + f"\n========== Your Config ========\n" + f"{pprint.pformat(gpc.config)}\n" + f"================================\n", + ranks=[0]) + + # cudnn + cudnn_benchmark = config.get('cudnn_benchmark', False) + cudnn_deterministic = config.get('cudnn_deterministic', False) + torch.backends.cudnn.benchmark = cudnn_benchmark + torch.backends.cudnn.deterministic = cudnn_deterministic + if verbose: + logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0]) + + # zero + use_zero = hasattr(gpc.config, 'zero') + if use_zero: + zero_cfg = gpc.config.get('zero', None) + if zero_cfg is not None: + cfg_ = zero_cfg.copy() + else: + cfg_ = {} + optimizer_config = zero_cfg.get('optimizer_config', None) + model_config = zero_cfg.get('model_config', None) + model, optimizer = convert_to_zero_v2(model, + optimizer, + model_config=model_config, + optimizer_config=optimizer_config) + + logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0]) + else: + if isinstance(model, nn.Module): + # first sync model across dp ranks + model.to(get_current_device()) + elif isinstance(model, Callable): + model = model().to(get_current_device()) + + # optimizer maybe a optimizer_cls + logger.warning("Initializing an non ZeRO model with optimizer class") + if isinstance(optimizer, Callable): + optimizer = optimizer(model.parameters()) + + if not use_zero: + if is_using_sequence(): + sync_model_param(model, ParallelMode.SEQUENCE_DP) + elif MOE_CONTEXT.is_initialized: + sync_moe_model_param(model) + elif is_using_ddp(): + sync_model_param(model, ParallelMode.DATA) + else: + logger.warning( + "The parameters of models is not automatically synchronized.\n" + "Please make sure that all parameters are the same in data parallel group.", + ranks=[0]) + + # check amp and zero + fp16_cfg = gpc.config.get('fp16', None) + + if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero: + raise ConfigException( + "It is not allowed to set fp16 and zero configuration in your config file at the same time") + + # clip grad norm + clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0) + + # initialize amp + amp_mode = None + if fp16_cfg is not None and fp16_cfg.mode is not None: + cfg_ = fp16_cfg.copy() + amp_mode = cfg_.pop('mode') + if is_using_pp(): + assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently' + if amp_mode == AMP_TYPE.NAIVE: + cfg_['clip_grad_norm'] = clip_grad_norm + model, optimizer, criterion = convert_to_amp(model=model, + optimizer=optimizer, + criterion=criterion, + mode=amp_mode, + amp_config=cfg_) + + # get torch ddp config + torch_ddp_cfg = gpc.config.get('torch_ddp', dict()) + + # gradient handler + gradient_handler_cfg = gpc.config.get('gradient_handler', None) + if gradient_handler_cfg is None: + # if gradient handler is not specified in the configuration file, + # check in the following order + # 1. if optimizer is ZERO, then use zero grad handler + # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp + # 3. if using pipeline and dp size larger than 1, use data parallel grad handler + if isinstance(optimizer, ShardedOptimizerV2): + gradient_handler_cfg = [dict(type='ZeROGradientHandler')] + if verbose: + logger.info( + "Training with zero is detected, ZeROGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0]) + elif is_using_ddp() and MOE_CONTEXT.is_initialized: + gradient_handler_cfg = [dict(type='MoeGradientHandler')] + if verbose: + logger.info( + "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0]) + elif is_using_sequence(): + model = DDP(model, + process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), + device_ids=[torch.cuda.current_device()], + **torch_ddp_cfg) + if verbose: + logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', + ranks=[0]) + elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE: + model = DDP(model, + process_group=gpc.get_group(ParallelMode.DATA), + device_ids=[torch.cuda.current_device()], + **torch_ddp_cfg) + if verbose: + logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0]) + elif is_using_ddp(): + gradient_handler_cfg = [dict(type='DataParallelGradientHandler')] + if verbose: + logger.info( + "Data parallel training is detected when using pipeline parallel, " + "DataParallelGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0]) + # add pipeline parallel gradient handler, if pipeline shared module is detected + for param in model.parameters(): + if getattr(param, 'pipeline_shared_module_pg', None) is not None: + if gradient_handler_cfg is None: + gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')] + else: + gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler')) + if verbose: + logger.info( + "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0]) + break + else: + if not isinstance(gradient_handler_cfg, list): + raise ConfigException( + f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}" + ) + + # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time + # to avoid duplicated buffer synchronization + if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel): + model.module.sync_buffer = False + + # initialize schedule for engine + if is_using_pp(): + tensor_shape = get_tensor_shape() + use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks') + if gpc.is_initialized(ParallelMode.PARALLEL_1D): + scatter_gather = True + else: + scatter_gather = False + if use_interleaved: + if isinstance(model, nn.Sequential): + model = nn.ModuleList([model]) + schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, + gpc.config.model.num_chunks, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather) + else: + schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather) + else: + schedule = NonPipelineSchedule() + + if gradient_handler_cfg is None: + gradient_handlers = None + if verbose and not isinstance(model, DDP): + logger.warning( + "No PyTorch DDP or gradient handler is set up, please make sure you do not need " + "to all-reduce the gradients after a training step.", + ranks=[0]) + else: + gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg] + + # check if optimizer is ColossalaiOptimizer + if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)): + optimizer = ColossalaiOptimizer(optim=optimizer) + + # gradient accumulation + grad_accum_size = gpc.config.get('gradient_accumulation', None) + if grad_accum_size is not None: + optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient( + model=model, + optimizer=optimizer, + dataloader=train_dataloader, + accumulate_size=grad_accum_size, + gradient_handlers=gradient_handlers, + lr_scheduler=lr_scheduler) + engine = Engine(model=model, + optimizer=optimizer, + criterion=criterion, + gradient_handlers=gradient_handlers, + clip_grad_norm=clip_grad_norm, + ophook_list=ophooks, + schedule=schedule) + + return engine, train_dataloader, test_dataloader, lr_scheduler diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..42c95729ac4145a2c84a5632216d77f28ecf7143 --- /dev/null +++ b/colossalai/kernel/__init__.py @@ -0,0 +1,3 @@ +from .cuda_native import LayerNorm, FusedScaleMaskSoftmax, MultiHeadAttention + +__all__ = ["LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"] diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f857ff5d9f1f6722a1bbb3153d0f319bd423586 --- /dev/null +++ b/colossalai/kernel/cuda_native/__init__.py @@ -0,0 +1,3 @@ +from .layer_norm import MixedFusedLayerNorm as LayerNorm +from .multihead_attention import MultiHeadAttention +from .scaled_softmax import FusedScaleMaskSoftmax diff --git a/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp b/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp new file mode 100644 index 0000000000000000000000000000000000000000..94f132521771bb18638d3a7edf03f7e4e14dcc27 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp @@ -0,0 +1,49 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu +#include + +void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + float scale); + +void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + float wd, float momentum, float dampening, float lr, + bool nesterov, bool first_run, + bool wd_after_momentum, float scale); + +void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, + const float beta2, const float epsilon, + const int step, const int mode, + const int bias_correction, const float weight_decay, + const float div_scale); + +void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, + const float beta2, const float epsilon, + const int step, const int bias_correction, + const float weight_decay, const int grad_averaging, + const int mode, at::Tensor global_grad_norm, + const float max_grad_norm, + at::optional use_nvlamb_python); + +std::tuple multi_tensor_l2norm_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("multi_tensor_scale", &multi_tensor_scale_cuda, + "Fused overflow check + scale for a list of contiguous tensors"); + m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda, + "Fused SGD optimizer for list of contiguous tensors"); + m.def("multi_tensor_adam", &multi_tensor_adam_cuda, + "Compute and apply gradient update to parameters for Adam optimizer"); + m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda, + "Computes and apply update for LAMB optimizer"); + m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda, + "Computes L2 norm for a list of contiguous tensors"); +} diff --git a/colossalai/kernel/cuda_native/csrc/compat.h b/colossalai/kernel/cuda_native/csrc/compat.h new file mode 100644 index 0000000000000000000000000000000000000000..00066dc95475296168c799904dc595ed435d2b0a --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/compat.h @@ -0,0 +1,10 @@ +// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0ab250218da38f9ded00766d6546d73b918699fa --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp @@ -0,0 +1,459 @@ +/* +Copyright (c) Microsoft Corporation. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE +*/ +#include "cpu_adam.h" + +#include +#include +#include + +#include +#include +#include +#include + +// C++ interface + +void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, + float *_exp_avg_sq, size_t _param_size, + bool param_half_precision, bool grad_half_precision, + float loss_scale) { + size_t rounded_size = 0; + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; + + __half *params_cast_h = NULL; + __half *grads_cast_h = NULL; + + if (param_half_precision) { + params_cast_h = reinterpret_cast<__half *>(_params); + } + if (grad_half_precision) { + grads_cast_h = reinterpret_cast<__half *>(grads); + } + +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + AVX_Data weight_decay_4; + if (_weight_decay > 0) + weight_decay_4.data = + (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH) { + AVX_Data grad_4; + if (grad_half_precision) { + grad_4.data = SIMD_LOAD_HALF(grads_cast_h + i); + } else { + grad_4.data = SIMD_LOAD(grads + i); + } + if (loss_scale > 0) { + AVX_Data loss_scale_vec; + loss_scale_vec.data = SIMD_SET(loss_scale); + grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data); + } + AVX_Data momentum_4; + momentum_4.data = SIMD_LOAD(_exp_avg + i); + + AVX_Data variance_4; + variance_4.data = SIMD_LOAD(_exp_avg_sq + i); + + AVX_Data param_4; + if (param_half_precision) { + param_4.data = SIMD_LOAD_HALF(params_cast_h + i); + } else { + param_4.data = SIMD_LOAD(_params + i); + } + + if (_weight_decay > 0 && !_adamw_mode) { + grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data); + } + momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data); + momentum_4.data = + SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data); + variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data); + grad_4.data = SIMD_MUL(grad_4.data, grad_4.data); + variance_4.data = + SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data); + grad_4.data = SIMD_SQRT(variance_4.data); + grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data); + grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data); + + if (_weight_decay > 0 && _adamw_mode) { + param_4.data = + SIMD_FMA(param_4.data, weight_decay_4.data, param_4.data); + } + param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data); + + if (param_half_precision) { + SIMD_STORE_HALF((float *)(params_cast_h + i), param_4.data); + } else { + SIMD_STORE(_params + i, param_4.data); + } + SIMD_STORE(_exp_avg + i, momentum_4.data); + SIMD_STORE(_exp_avg_sq + i, variance_4.data); + } + } +#endif + if (_param_size > rounded_size) { + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = grad_half_precision ? (float)grads_cast_h[k] : grads[k]; + if (loss_scale > 0) { + grad /= loss_scale; + } + float param = + param_half_precision ? (float)params_cast_h[k] : _params[k]; + float momentum = _exp_avg[k]; + float variance = _exp_avg_sq[k]; + if (_weight_decay > 0 && !_adamw_mode) { + grad = param * _weight_decay + grad; + } + momentum = momentum * _betta1; + momentum = grad * betta1_minus1 + momentum; + + variance = variance * _betta2; + grad = grad * grad; + variance = grad * betta2_minus1 + variance; + + grad = sqrt(variance); + grad = grad * _bias_correction2 + _eps; + grad = momentum / grad; + if (_weight_decay > 0 && _adamw_mode) { + param += w_decay * param; + } + param = grad * step_size + param; + + if (param_half_precision) + params_cast_h[k] = (__half)param; + else + _params[k] = param; + _exp_avg[k] = momentum; + _exp_avg_sq[k] = variance; + } + } + } +} + +void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, + float *_exp_avg_sq, size_t _param_size, + bool param_half_precision, bool grad_half_precision, + float loss_scale) { + size_t rounded_size = 0; + + __half *params_cast_h = NULL; + __half *grads_cast_h = NULL; + if (param_half_precision) { + params_cast_h = reinterpret_cast<__half *>(_params); + } + if (grad_half_precision) { + grads_cast_h = reinterpret_cast<__half *>(grads); + } + +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + float betta1_minus1 = 1 - _betta1; + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + float betta2_minus1 = 1 - _betta2; + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha / _bias_correction1; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay_4; + if (_weight_decay > 0) + weight_decay_4.data = + (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) { + AVX_Data grad_4[4]; + AVX_Data momentum_4[4]; + AVX_Data variance_4[4]; + AVX_Data param_4[4]; +#pragma unroll 4 + for (int j = 0; j < 4; j++) { + if (grad_half_precision) { + grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j); + } else { + grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j); + } + + if (loss_scale > 0) { + AVX_Data loss_scale_vec; + loss_scale_vec.data = SIMD_SET(loss_scale); + grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); + } + + momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); + variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); + + if (param_half_precision) { + param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j); + } else { + param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j); + } + + if (_weight_decay > 0 && !_adamw_mode) { + grad_4[j].data = + SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data); + } + momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data); + momentum_4[j].data = + SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data); + variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data); + grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data); + variance_4[j].data = + SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data); + grad_4[j].data = SIMD_SQRT(variance_4[j].data); + grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data); + grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data); + + if (_weight_decay > 0 && _adamw_mode) { + param_4[j].data = + SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data); + } + param_4[j].data = + SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data); + if (param_half_precision) { + SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j), + param_4[j].data); + } else { + SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data); + } + SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data); + SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data); + } + } + } +#endif + if (_param_size > rounded_size) + Step_1((param_half_precision ? (float *)(params_cast_h + rounded_size) + : _params + rounded_size), + (grad_half_precision ? (float *)(grads_cast_h + rounded_size) + : grads + rounded_size), + (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), param_half_precision, + grad_half_precision, loss_scale); +} + +void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, + float *_exp_avg_sq, size_t _param_size, + bool param_half_precision, bool grad_half_precision, + float loss_scale) { + size_t rounded_size = 0; + __half *params_cast_h = NULL; + __half *grads_cast_h = NULL; + if (param_half_precision) { + params_cast_h = reinterpret_cast<__half *>(_params); + } + if (grad_half_precision) { + grads_cast_h = reinterpret_cast<__half *>(grads); + } +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + float betta1_minus1 = 1 - _betta1; + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + float betta2_minus1 = 1 - _betta2; + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha / _bias_correction1; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay_4; + if (_weight_decay > 0) + weight_decay_4.data = + (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) { + AVX_Data grad_4[8]; + AVX_Data momentum_4[8]; + AVX_Data variance_4[8]; + AVX_Data param_4[8]; +#pragma unroll 8 + for (int j = 0; j < 8; j++) { + if (grad_half_precision) { + grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j); + } else { + grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j); + } + + if (loss_scale > 0) { + AVX_Data loss_scale_vec; + loss_scale_vec.data = SIMD_SET(loss_scale); + grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); + } + + momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); + variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); + + if (param_half_precision) { + param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j); + } else { + param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j); + } + + if (_weight_decay > 0 && !_adamw_mode) { + grad_4[j].data = + SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data); + } + momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data); + momentum_4[j].data = + SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data); + variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data); + grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data); + variance_4[j].data = + SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data); + grad_4[j].data = SIMD_SQRT(variance_4[j].data); + grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data); + grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data); + if (_weight_decay > 0 && _adamw_mode) { + param_4[j].data = + SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data); + } + param_4[j].data = + SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data); + + if (param_half_precision) { + SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j), + param_4[j].data); + } else { + SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data); + } + + SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data); + SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data); + } + } + } +#endif + if (_param_size > rounded_size) + Step_4((param_half_precision ? (float *)(params_cast_h + rounded_size) + : _params + rounded_size), + (grad_half_precision ? (float *)(grads_cast_h + rounded_size) + : grads + rounded_size), + (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), param_half_precision, + grad_half_precision, loss_scale); +} + +void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2, + float epsilon, float weight_decay, + bool bias_correction, torch::Tensor ¶ms, + torch::Tensor &grads, torch::Tensor &exp_avg, + torch::Tensor &exp_avg_sq, float loss_scale) { + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + + float *params_ptr = (float *)params_c.data_ptr(); + float *grads_ptr = (float *)grads_c.data_ptr(); + float *exp_avg_ptr = (float *)exp_avg_c.data_ptr(); + float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr(); + + this->IncrementStep(step, beta1, beta2); + this->update_state(lr, epsilon, weight_decay, bias_correction); + this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, + params_c.numel(), (params.options().dtype() == at::kHalf), + (grads.options().dtype() == at::kHalf), loss_scale); +} + +namespace py = pybind11; + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + py::class_(m, "CPUAdamOptimizer") + .def(py::init()) + .def("step", &Adam_Optimizer::step); +} diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/colossalai/kernel/cuda_native/csrc/cpu_adam.h new file mode 100644 index 0000000000000000000000000000000000000000..2df191e8e514c0d2b7105a7cd91cf1a0ae7c63fd --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.h @@ -0,0 +1,164 @@ +/* +Copyright (c) Microsoft Corporation. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE +*/ +#pragma once + +#include +#include +#include +#include +#include +#include +#if (__x86_64__ || __i386__) +#include +#include +#endif + +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define TILE (128 * 1024 * 1024) + +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + +#if defined(__AVX512__) +#define SIMD_WIDTH 16 +#define INTV __m256i +#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm512_loadu_ps(x) +#define SIMD_SET(x) _mm512_set1_ps(x) +#define SIMD_ADD(x, y) _mm512_add_ps(x, y) +#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm512_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm512_div_ps(x, y) +#define SIMD_LOAD_HALF(x) \ + _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) +#define SIMD_STORE_HALF(x, d) \ + _mm256_store_ps( \ + x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) + +#elif defined(__AVX256__) or defined(__AVX2__) +#define SIMD_WIDTH 8 +#define INTV __m128i +#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm256_loadu_ps(x) +#define SIMD_SET(x) _mm256_set1_ps(x) +#define SIMD_ADD(x, y) _mm256_add_ps(x, y) +#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm256_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm256_div_ps(x, y) +#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) +#define SIMD_STORE_HALF(x, d) \ + _mm_store_ps( \ + x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) + +#endif + +union AVX_Data { +#if defined(__AVX512__) + __m512 data; +#elif defined(__AVX256__) or defined(__AVX2__) + __m256 data; +#endif + // float data_f[16]; +}; + +#endif + +#define STEP(SPAN) \ + void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \ + float *_exp_avg_sq, size_t _param_size, \ + bool param_half_precision = false, \ + bool grad_half_precision = false, float loss_scale = -1); + +class Adam_Optimizer { + public: + Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999, + float eps = 1e-8, float weight_decay = 0, + bool adamw_mode = true) + : _alpha(alpha), + _betta1(betta1), + _betta2(betta2), + _eps(eps), + _weight_decay(weight_decay), + _betta1_t(1.0), + _betta2_t(1.0), + _step(0), + _adamw_mode(adamw_mode) {} + ~Adam_Optimizer() {} + + STEP(1) + STEP(4) + STEP(8) + inline void IncrementStep(size_t step, float beta1, float beta2) { + if (beta1 != _betta1 || beta2 != _betta2) { + _step = step; + _betta1 = beta1; + _betta2 = beta2; + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + } else { + _step++; + if (_step != step) { + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + _step = step; + } else { + _betta1_t *= _betta1; + _betta2_t *= _betta2; + } + } + } + inline void update_state(float lr, float epsilon, float weight_decay, + bool bias_correction) { + _alpha = lr; + _eps = epsilon; + _weight_decay = weight_decay; + + _bias_correction1 = 1.0f; + _bias_correction2 = 1.0f; + if (bias_correction == 1) { + _bias_correction1 = 1 - _betta1_t; + _bias_correction2 = 1 / sqrt(1 - _betta2_t); + } + } + + void step(size_t step, float lr, float beta1, float beta2, float epsilon, + float weight_decay, bool bias_correction, torch::Tensor ¶ms, + torch::Tensor &grads, torch::Tensor &exp_avg, + torch::Tensor &exp_avg_sq, float loss_scale); + + private: + float _alpha; + float _betta1; + float _betta2; + float _eps; + float _weight_decay; + + float _betta1_t; + float _betta2_t; + size_t _step; + + float _bias_correction1; + float _bias_correction2; + + bool _adamw_mode; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu b/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu new file mode 100644 index 0000000000000000000000000000000000000000..58d26235a9cc6954e9822119f215b9745b0a1684 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu @@ -0,0 +1,191 @@ +#include "block_reduce.h" +#include "cuda_util.h" +#include "kernels.h" +#include "ls_cub.cuh" + +ls::cub::CachingDeviceAllocator g_allocator(true); + +template +__global__ void ls_cross_entropy_fw_kernel( + const T *__restrict__ inputs, const int *__restrict__ targets, + float *__restrict__ outputs, float *__restrict__ nll_loss_outputs, + const int padding_idx, const float epsilon, const int vocab_size) { + /* step1: compute each thread's max_logit and sum_exp_logit, store in + * max_input, sum_exp_logit */ + const int block_start = blockIdx.x * vocab_size; + const int left_idx = block_start + threadIdx.x; + const int right_idx = (blockIdx.x + 1) * vocab_size; + float max_input[1] = {REDUCE_FLOAT_INF_NEG}; + float sum_logits[2] = {0.f, 0.f}; // logit and logit exp + int target_tid = targets[blockIdx.x]; + + if (target_tid == padding_idx) { + if (threadIdx.x == 0) { + nll_loss_outputs[blockIdx.x] = 0.f; + outputs[blockIdx.x] = 0.f; + } + return; + } + + for (int i = left_idx; i < right_idx; i += blockDim.x) { + max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); + } + blockReduce(max_input); + __shared__ float s_max_input; + if (threadIdx.x == 0) { + s_max_input = max_input[0]; + } + __syncthreads(); + + for (int i = left_idx; i < right_idx; i += blockDim.x) { + float logit = static_cast(inputs[i]) - s_max_input; + sum_logits[0] += logit; + sum_logits[1] += expf(logit); + } + + blockReduce(sum_logits); + __shared__ float s_sum_logit; + __shared__ float s_sum_exp; + if (threadIdx.x == 0) { + s_sum_logit = sum_logits[0]; + s_sum_exp = sum_logits[1]; + } + __syncthreads(); + + float eps_i = epsilon / (vocab_size - 1); + if (threadIdx.x == 0) { + // neg_log_prob = log(sum(exp(x - x_max))) - (x - x_max) + float nll_loss = logf(s_sum_exp) - + static_cast(inputs[block_start + target_tid]) + + s_max_input; + nll_loss_outputs[blockIdx.x] = nll_loss; + float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit; + outputs[blockIdx.x] = + (1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss; + } +} + +template +__global__ void ls_cross_entropy_bw_kernel( + const float *__restrict__ grad_outputs, const T *__restrict__ inputs, + const int *__restrict__ targets, T *__restrict__ grad_inputs, + const int padding_idx, const float epsilon, const int vocab_size) { + /* step1: compute each thread's max_logit and sum_exp_logit, store in + * max_input, sum_exp_logit */ + const int block_start = blockIdx.x * vocab_size; + const int left_idx = block_start + threadIdx.x; + const int right_idx = (blockIdx.x + 1) * vocab_size; + float max_input[1] = {REDUCE_FLOAT_INF_NEG}; + float sum_logits[1] = {0.f}; + const float grad_out = static_cast(grad_outputs[0]); + int target_tid = targets[blockIdx.x]; + + if (target_tid == padding_idx) { + for (int i = left_idx; i < right_idx; i += blockDim.x) { + grad_inputs[i] = 0.f; + } + return; + } + + for (int i = left_idx; i < right_idx; i += blockDim.x) { + max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); + } + blockReduce(max_input); + __shared__ float s_max_input; + if (threadIdx.x == 0) { + s_max_input = max_input[0]; + } + __syncthreads(); + + for (int i = left_idx; i < right_idx; i += blockDim.x) { + float logit = static_cast(inputs[i]) - s_max_input; + sum_logits[0] += expf(logit); + } + + blockReduce(sum_logits); + __shared__ float s_sum_exp; + if (threadIdx.x == 0) { + s_sum_exp = sum_logits[0]; + } + __syncthreads(); + + float eps_i = epsilon / (vocab_size - 1); + float nll_weight = 1.0 - epsilon - eps_i; + + for (int i = left_idx; i < right_idx; i += blockDim.x) { + float prob = expf(static_cast(inputs[i]) - s_max_input) / s_sum_exp; + float grad = 0; + grad += (vocab_size * prob - 1) * eps_i; + grad += prob * nll_weight; + if ((i - block_start) == target_tid) { + grad -= nll_weight; + } + grad_inputs[i] = grad_out * grad; + } +} + +template +void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr, + float *outputs_ptr, float *nll_loss_ptr, + float *loss_buffer, const int padding_idx, + const float epsilon, const int batch_size, + const int seq_len, const int vocab_size, + cudaStream_t stream) { + int grid_dim = batch_size * seq_len; + float *nll_loss_buffer = loss_buffer + grid_dim; + ls_cross_entropy_fw_kernel<<>>( + inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx, + epsilon, vocab_size); + + int num_items = grid_dim; + void *d_temp_storage = NULL; + size_t temp_storage_bytes = 0; + CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, + loss_buffer, outputs_ptr, + num_items, stream)); + CHECK_GPU_ERROR( + g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes)); + CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, + loss_buffer, outputs_ptr, + num_items, stream)); + CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, + nll_loss_buffer, nll_loss_ptr, + num_items, stream)); + CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage)); +} + +template void launch_cross_entropy_fw( + const float *inputs_ptr, const int *targets_ptr, float *outputs_ptr, + float *nll_loss_ptr, float *loss_buffer, const int padding_idx, + const float epsilon, const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream); + +template void launch_cross_entropy_fw<__half>( + const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr, + float *nll_loss_ptr, float *loss_buffer, const int padding_idx, + const float epsilon, const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream); + +template +void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr, + const int *targets_ptr, T *grad_inputs_ptr, + const int padding_idx, const float epsilon, + const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream) { + int grid_dim = batch_size * seq_len; + ls_cross_entropy_bw_kernel<<>>( + grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx, + epsilon, vocab_size); +} + +template void launch_cross_entropy_bw( + const float *grad_outputs_ptr, const float *inputs_ptr, + const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx, + const float epsilon, const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream); + +template void launch_cross_entropy_bw<__half>( + const float *grad_outputs_ptr, const __half *inputs_ptr, + const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx, + const float epsilon, const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu b/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu new file mode 100644 index 0000000000000000000000000000000000000000..6c49280ff2734a2ec08a4f628424320a62b3a7e7 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu @@ -0,0 +1,171 @@ +/* Copyright 2021 The LightSeq Team + Copyright Microsoft DeepSpeed + This file is adapted from Microsoft DeepSpeed +*/ +#include "cublas_wrappers.h" + +#ifdef COLOSSAL_HIP +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const float *A, + const float *B, float *C, rocblas_gemm_algo algo) { + cublasStatus_t status = + rocblas_gemm_ex(handle, transa, transb, m, n, k, (const void *)alpha, + (const void *)A, rocblas_datatype_f32_r, (transa == rocblas_operation_none) ? m : k, + (const void *)B, rocblas_datatype_f32_r, (transb == rocblas_operation_none) ? k : n, + (const void *)beta, C, rocblas_datatype_f32_r, m, C, rocblas_datatype_f32_r, m, rocblas_datatype_f32_r, algo, 0, 0); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, n, k, (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const __half *A, + const __half *B, __half *C, rocblas_gemm_algo algo) { + cublasStatus_t status = rocblas_gemm_ex( + handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A, + rocblas_datatype_f16_r, (transa == rocblas_operation_none) ? m : k, (const void *)B, rocblas_datatype_f16_r, + (transb == rocblas_operation_none) ? k : n, (const void *)beta, (void *)C, + rocblas_datatype_f16_r, m, (void *)C, rocblas_datatype_f16_r, m, rocblas_datatype_f32_r, algo, 0, 0); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, n, k, (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, + const float *alpha, const float *beta, + const float *A, const float *B, float *C, + cublasOperation_t op_A, cublasOperation_t op_B, + int stride_A, int stride_B, int stride_C, + int batch, rocblas_gemm_algo algo) { + cublasStatus_t status = rocblas_gemm_strided_batched_ex( + handle, op_A, op_B, m, n, k, alpha, A, rocblas_datatype_f32_r, + (op_A == rocblas_operation_none) ? m : k, stride_A, B, rocblas_datatype_f32_r, + (op_B == rocblas_operation_none) ? k : n, stride_B, beta, C, rocblas_datatype_f32_r, m, stride_C, + C, rocblas_datatype_f16_r, m, stride_C, batch, rocblas_datatype_f32_r, algo, 0, 0); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, " + "error: %d) \n", + batch, m, n, k, (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, + const float *alpha, const float *beta, + const __half *A, const __half *B, __half *C, + cublasOperation_t op_A, cublasOperation_t op_B, + int stride_A, int stride_B, int stride_C, + int batch, rocblas_gemm_algo algo) { + cublasStatus_t status = rocblas_gemm_strided_batched_ex( + handle, op_A, op_B, m, n, k, alpha, A, rocblas_datatype_f16_r, + (op_A == rocblas_operation_none) ? m : k, stride_A, B, rocblas_datatype_f16_r, + (op_B == rocblas_operation_none) ? k : n, stride_B, beta, C, rocblas_datatype_f16_r, m, stride_C, + C, rocblas_datatype_f16_r, m, stride_C, batch, rocblas_datatype_f32_r, algo, 0, 0); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, n, k, (int)status); + return EXIT_FAILURE; + } + + return 0; +} +#else +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const float *A, + const float *B, float *C, cublasGemmAlgo_t algo) { + cublasStatus_t status = + cublasGemmEx(handle, transa, transb, m, n, k, (const void *)alpha, + (const void *)A, CUDA_R_32F, (transa == CUBLAS_OP_N) ? m : k, + (const void *)B, CUDA_R_32F, (transb == CUBLAS_OP_N) ? k : n, + (const void *)beta, C, CUDA_R_32F, m, CUDA_R_32F, algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, n, k, (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const __half *A, + const __half *B, __half *C, cublasGemmAlgo_t algo) { + cublasStatus_t status = cublasGemmEx( + handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A, + CUDA_R_16F, (transa == CUBLAS_OP_N) ? m : k, (const void *)B, CUDA_R_16F, + (transb == CUBLAS_OP_N) ? k : n, (const void *)beta, (void *)C, + CUDA_R_16F, m, CUDA_R_32F, algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, n, k, (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, + const float *alpha, const float *beta, + const float *A, const float *B, float *C, + cublasOperation_t op_A, cublasOperation_t op_B, + int stride_A, int stride_B, int stride_C, + int batch, cublasGemmAlgo_t algo) { + cublasStatus_t status = cublasGemmStridedBatchedEx( + handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_32F, + (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_32F, + (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_32F, m, stride_C, + batch, CUDA_R_32F, algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, " + "error: %d) \n", + batch, m, n, k, (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, + const float *alpha, const float *beta, + const __half *A, const __half *B, __half *C, + cublasOperation_t op_A, cublasOperation_t op_B, + int stride_A, int stride_B, int stride_C, + int batch, cublasGemmAlgo_t algo) { + cublasStatus_t status = cublasGemmStridedBatchedEx( + handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_16F, + (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_16F, + (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_16F, m, stride_C, + batch, CUDA_R_32F, algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, n, k, (int)status); + return EXIT_FAILURE; + } + + return 0; +} +#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu new file mode 100644 index 0000000000000000000000000000000000000000..5a45a7647a9864c088a4897af4129b5d3f9375d7 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu @@ -0,0 +1,176 @@ +#include +#include + +#ifdef COLOSSAL_HIP +#include +#include "hip_util.h" +#else +#include "cuda_util.h" +#endif + +/* GPU function guard */ +std::string _cudaGetErrorString(cudaError_t error) { + return cudaGetErrorString(error); +} + +std::string _cudaGetErrorString(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; +#ifndef COLOSSAL_HIP + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; +#endif + } + return "CUBLAS_UNKNOW"; +} + +template +void check_gpu_error(T result, char const *const func, const char *const file, + int const line) { + if (result) { + throw std::runtime_error(std::string("[CUDA][ERROR] ") + +file + "(" + + std::to_string(line) + + "): " + (_cudaGetErrorString(result)) + "\n"); + } +} + +template void check_gpu_error(cudaError_t result, + char const *const func, + const char *const file, + int const line); +template void check_gpu_error(cublasStatus_t result, + char const *const func, + const char *const file, + int const line); + +template +void print_vec(const T *outv, std::string outn, int num_output_ele) { + std::cout << outn << ": "; + std::vector hout(num_output_ele, (T)0); + cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(T), + cudaMemcpyDeviceToHost); + for (int i = 0; i < num_output_ele; i++) { + std::cout << hout[i] << ", "; + } + std::cout << std::endl; +} + +template <> +void print_vec<__half>(const __half *outv, std::string outn, + int num_output_ele) { + std::cout << outn << ": "; + std::vector<__half> hout(num_output_ele, (__half)0.f); + cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(__half), + cudaMemcpyDeviceToHost); + for (int i = 0; i < num_output_ele; i++) { + std::cout << __half2float(hout[i]) << ", "; + } + std::cout << std::endl; +} + +template void print_vec(const float *outv, std::string outn, + int num_output_ele); + +template void print_vec(const int *outv, std::string outn, + int num_output_ele); + +template void print_vec<__half>(const __half *outv, std::string outn, + int num_output_ele); + +template +T *cuda_malloc(size_t ele_num) { + size_t byte_size = ele_num * sizeof(T); + T *pdata = nullptr; + CHECK_GPU_ERROR(cudaMalloc((void **)&pdata, byte_size)); + return pdata; +} + +template float *cuda_malloc(size_t ele_num); + +template __half *cuda_malloc<__half>(size_t ele_num); + +template uint8_t *cuda_malloc(size_t ele_num); + +void cuda_free(void *pdata) { + if (pdata != nullptr) { + cudaFree(pdata); + } +} + +template +struct _isnan { + __device__ bool operator()(T a) const { return isnan(a); } +}; + +template <> +struct _isnan<__half> { + __device__ bool operator()(const __half a) const { return __hisnan(a); } +}; + +template +struct _isinf { + __device__ bool operator()(T a) const { return isinf(a); } +}; + +template <> +struct _isinf<__half> { + __device__ bool operator()(const __half a) const { return __hisinf(a); } +}; + +template +void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf, + std::string file, int line, cudaStream_t stream) { + // check_nan_inf = 0 for checking nan + // check_nan_inf = 1 for checking inf + bool res = false; + std::string msg = file + "(" + std::to_string(line) + "): "; + if (check_nan_inf) { + msg += "nan."; + res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr, + data_ptr + dsize, _isnan(), false, + thrust::logical_or()); + } else { + msg += "inf."; + res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr, + data_ptr + dsize, _isinf(), false, + thrust::logical_or()); + } + + if (res) { + throw std::runtime_error(msg); + } + std::cout << msg << " [check pass]." << std::endl; +} + +template void check_nan_inf(const float *data_ptr, int dsize, + bool check_nan_inf, std::string file, + int line, cudaStream_t stream); + +template void check_nan_inf<__half>(const __half *data_ptr, int dsize, + bool check_nan_inf, std::string file, + int line, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..ad00ac1dcf0de9437e8df95ed2381cc619851475 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu @@ -0,0 +1,1041 @@ +#include +#include + +#include "kernels.h" +#ifdef COLOSSAL_HIP +#include +#endif + +#ifndef COLOSSAL_HIP +#include + +namespace cg = cooperative_groups; +#endif + +curandStatePhilox4_32_10_t *curandstate; + +/** + * @brief element-wise activation function on device, like Relu, Gelu + * + * @tparam enum class ActivationType, kRelu, kGelu + * @tparam input type + * @param any shape of float and __half2 + * @return same shape and type with input + */ +template +__forceinline__ __device__ T activation_kernel(T x); + +template <> +__device__ float activation_kernel(float x) { + float cdf = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +template <> +__device__ __half2 +activation_kernel(__half2 val) { + __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); + float2 tmp_pow = __half22float2(val_pow3); + float2 tmp = __half22float2(val); + + tmp.x = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + return __hmul2(val, __float22half2_rn(tmp)); +} + +template <> +__device__ float activation_kernel(float x) { + return fmaxf(x, 0); +} + +template <> +__device__ __half2 +activation_kernel(__half2 x) { +#ifdef COLOSSAL_HIP + float2 tmp = __half22float2(x); + return __floats2half2_rn(fmaxf(0.f, tmp.x), + fmaxf(0.f, tmp.y)); +#else + return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)), + fmaxf(0.f, __half2float(x.y))); +#endif +} + +/** + * @brief element-wise activation backward function on device + * + * @tparam enum class ActivationType + * @tparam input type + * @param any shape of float and __half2 + * @return same shape of input + */ +template +__forceinline__ __device__ T activation_bwd_kernel(T grad, T x); + +template <> +__device__ float activation_bwd_kernel(float grad, + float x) { + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return grad * (dg1 + dg2 + dg3); +} + +template <> +__device__ __half activation_bwd_kernel( + __half grad, __half x_half) { + float x = __half2float(x_half); + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return grad * __float2half(dg1 + dg2 + dg3); +} + +template <> +__device__ float activation_bwd_kernel(float grad, + float x) { + return x > 0.f ? grad : 0.f; +} + +template <> +__device__ __half +activation_bwd_kernel(__half grad, __half x) { + const __half half_zero = __float2half(0.f); + return x > half_zero ? grad : half_zero; +} + +template <> +__device__ __half2 activation_bwd_kernel( + __half2 grad2, __half2 x_half2) { +#ifdef COLOSSAL_HIP + float2 tmp_x = __half22float2(x_half2); + float2 tmp_grad2 = __half22float2(grad2); + + return __floats2half2_rn(tmp_x.x > 0.0 ? tmp_grad2.x : 0.0, + tmp_x.y > 0.0 ? tmp_grad2.y : 0.0); +#else + const __half half_zero = __float2half(0.f); + return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero, + x_half2.y > half_zero ? grad2.y : half_zero); +#endif +} + +/** + * @brief init curand states in global memory + * + * @thread grid_dim * block*dim to suuport any size of states + * @param state persistant curand states + * @param seed seed to init states + * @return void + */ +__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state, + int seed) { + /* Each thread gets same seed, a different sequence + number, no offset */ + int id = threadIdx.x + blockIdx.x * blockDim.x; + curand_init(seed, id, 0, &state[id]); +} + +void launch_curand_init(int total_count, int dim, cudaStream_t stream) { + cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t)); + int grid_dim = total_count >> 9; + curand_init_kernel<<>>( + curandstate, std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); +} + +/** + * @brief element-wise dropout, store dropped position in mask, it's not + * in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param out any size of float and __half + * @param in same with out + * @param mask uint8 type, same size with out + * @param seed seed to curand + * @return void + */ +__global__ void ls_dropout_kernel(const int total_count, const float ratio, + float *__restrict__ out, + const float *__restrict__ in, + uint8_t *__restrict__ mask, const int seed) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + + float4 input4 = data4[i]; + float4 res4; + res4.x = input4.x * scale * m[0]; + res4.y = input4.y * scale * m[1]; + res4.z = input4.z * scale * m[2]; + res4.w = input4.w * scale * m[3]; + out4[i] = res4; +} + +__global__ void ls_dropout_kernel(const int total_count, const float ratio, + __half *__restrict__ out, + const __half *__restrict__ in, + uint8_t *__restrict__ mask, const int seed) { + const float scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = (uint8_t)(rand.x > ratio); + m[5] = (uint8_t)(rand.y > ratio); + m[6] = (uint8_t)(rand.z > ratio); + m[7] = (uint8_t)(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = *m8; + + float4 val_float4 = vals_float4[i]; + float4 out_float4; + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); + __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); + __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); + __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); + out_half2[0] = __hmul2(val_half2[0], scale_mask_1); + out_half2[1] = __hmul2(val_half2[1], scale_mask_2); + out_half2[2] = __hmul2(val_half2[2], scale_mask_3); + out_half2[3] = __hmul2(val_half2[3], scale_mask_4); + outs_float4[i] = out_float4; +} + +/** + * @brief element-wise dropout backward with dropout mask, it's + * not in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param in any size of float and __half + * @param mask uint8 type, same size with in + * @return void + */ +__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, + float *out, const float *in, + const uint8_t *__restrict__ mask) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *in4 = reinterpret_cast(in); + const uint32_t *mask4 = reinterpret_cast(mask); + + uint32_t *m4 = reinterpret_cast(m); + m4[0] = mask4[i]; + + float4 input4 = in4[i]; + float4 res4; + res4.x = input4.x * scale * static_cast(m[0]); + res4.y = input4.y * scale * static_cast(m[1]); + res4.z = input4.z * scale * static_cast(m[2]); + res4.w = input4.w * scale * static_cast(m[3]); + out4[i] = res4; +} + +__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, + __half *out, const __half *in, + const uint8_t *__restrict__ mask) { + const __half scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + float4 *out4 = reinterpret_cast(out); + const float4 *vals_float4 = reinterpret_cast(in); + const uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + uint64_t *m8 = reinterpret_cast(m); + m8[0] = mask8[i]; + + float4 val_float4 = vals_float4[i]; + float4 out_float4; + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + __half2 scale_mask_1 = + __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); + __half2 scale_mask_2 = + __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); + __half2 scale_mask_3 = + __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); + __half2 scale_mask_4 = + __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); + out_half2[0] = __hmul2(val_half2[0], scale_mask_1); + out_half2[1] = __hmul2(val_half2[1], scale_mask_2); + out_half2[2] = __hmul2(val_half2[2], scale_mask_3); + out_half2[3] = __hmul2(val_half2[3], scale_mask_4); + out4[i] = out_float4; +} + +template <> +void launch_ls_dropout(float *out, const float *vals, uint8_t *mask, + int total_count, float ratio, cudaStream_t stream, + bool backward) { + int grid_dim = total_count >> 12; + if (!backward) { + ls_dropout_kernel<<>>( + total_count, ratio, out, vals, mask, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); + } else { + ls_dropout_bwd_kernel<<>>(total_count, ratio, + out, vals, mask); + } +} + +template <> +void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask, + int total_count, float ratio, + cudaStream_t stream, bool backward) { + int grid_dim = total_count >> 13; + if (!backward) { + ls_dropout_kernel<<>>( + total_count, ratio, out, vals, mask, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); + } else { + ls_dropout_bwd_kernel<<>>(total_count, ratio, + out, vals, mask); + } +} + +/** + * @brief fused bias, dropout, and residual at the end of Attention and FFN, + * store dropped position in mask, it's not in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param out [batch_size, seq_len, hidden_size], float and __half + * @param in [batch_size, seq_len, hidden_size], float and __half + * @param mask [batch_size, seq_len, hidden_size], uint8 type + * @param bias [hidden_size], ffn bias + * @param residual [batch_size, seq_len, hidden_size], float and __half + * @param seed seed to curand + * @param hidden_size hidden size + * @return void + */ +__global__ void ls_dropout_res_bias_kernel( + const int total_count, const float ratio, float *__restrict__ out, + const float *__restrict__ in, uint8_t *__restrict__ mask, + const float *__restrict__ bias, const float *__restrict__ residual, + const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + const float4 *residual4 = reinterpret_cast(residual); + const float4 *bias4 = reinterpret_cast(bias); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = static_cast(rand.x > ratio); + m[1] = static_cast(rand.y > ratio); + m[2] = static_cast(rand.z > ratio); + m[3] = static_cast(rand.w > ratio); + + int bias_i = i % (hidden_size >> 2); + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + const float4 input4 = data4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + const float4 res4 = residual4[i]; + float4 output4; + + output4.x = (input4.x + b4.x) * scale * m[0] + res4.x; + output4.y = (input4.y + b4.y) * scale * m[1] + res4.y; + output4.z = (input4.z + b4.z) * scale * m[2] + res4.z; + output4.w = (input4.w + b4.w) * scale * m[3] + res4.w; + + out4[i] = output4; +} + +__global__ void ls_dropout_res_bias_kernel( + const int total_count, const float ratio, __half *__restrict__ out, + const __half *__restrict__ in, uint8_t *__restrict__ mask, + const __half *__restrict__ bias, const __half *__restrict__ residual, + const int seed, const int hidden_size) { + const __half scale = 1. / (1. - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + const float4 *residual4 = reinterpret_cast(residual); + const float4 *bias4 = reinterpret_cast(bias); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = static_cast(rand.x > ratio); + m[1] = static_cast(rand.y > ratio); + m[2] = static_cast(rand.z > ratio); + m[3] = static_cast(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = static_cast(rand.x > ratio); + m[5] = static_cast(rand.y > ratio); + m[6] = static_cast(rand.z > ratio); + m[7] = static_cast(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = m8[0]; + + int bias_i = i % (hidden_size >> 3); + float4 val_float4 = vals_float4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + const float4 res4 = residual4[i]; + float4 out_float4; + + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + const __half2 *b_half2 = reinterpret_cast(&b4); + const __half2 *res_half2 = reinterpret_cast(&res4); + __half2 scale_mask_1 = + __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); + __half2 scale_mask_2 = + __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); + __half2 scale_mask_3 = + __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); + __half2 scale_mask_4 = + __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); + out_half2[0] = + __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]); + out_half2[1] = + __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]); + out_half2[2] = + __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]); + out_half2[3] = + __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]); + outs_float4[i] = out_float4; +} + +template <> +void launch_ls_dropout_res_bias(float *out, const float *vals, + uint8_t *mask, const float *bias, + const float *residual, int total_count, + int dim, float ratio, + cudaStream_t stream) { + int grid_dim = total_count >> 12; + ls_dropout_res_bias_kernel<<>>( + total_count, ratio, out, vals, mask, bias, residual, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals, + uint8_t *mask, const __half *bias, + const __half *residual, int total_count, + int dim, float ratio, + cudaStream_t stream) { + int grid_dim = total_count >> 13; + ls_dropout_res_bias_kernel<<>>( + total_count, ratio, out, vals, mask, bias, residual, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +/** + * @brief fused bias and dropout backward at the end of Attention and FFN + * + * @thread + * gridDim.x = hidden_size / 8 + * blockDim.x = 8 + * blockDim.y = 1024 / 8 = 128 + * + * @param row_size batch_size * seq_len + * @param ratio dropout ratio + * @param in_grad [batch_size, seq_len, hidden_size], input grad + * @param bias_grad [hidden_size], bias grad + * @param out_grad [batch_size, seq_len, hidden_size], output grad + * @param mask [batch_size, seq_len, hidden_size], dropout mask + * @param hidden_size + * @return void + */ +__global__ void ls_dropout_bias_bwd_kernel( + const int row_size, const float ratio, float *__restrict__ in_grad, + float *__restrict__ bias_grad, const float *__restrict__ out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + // every block generate 8 bias result + __shared__ float tile[8][129]; + +#ifndef COLOSSAL_HIP + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); +#endif + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); + int stride = hidden_size * 128; + float local_sum = 0; + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + for (int r = threadIdx.y; r < row_size; r += 128) { + float val = out_grad[idx]; + val *= scale * static_cast(mask[idx]); + local_sum += val; + in_grad[idx] = val; + idx += stride; + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + + float sum = 0; + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 7; + int y = tid & (127); + if (y < 32) { +#pragma unroll + for (int i = 0; i < 4; i++) { + sum += tile[x][y + i * 32]; + } + } + __syncthreads(); + +#ifdef COLOSSAL_HIP + for (int i = 1; i < 32; i <<= 1) sum += __shfl_down(sum, i); +#else + for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); +#endif + + if (y == 0) tile[0][x] = sum; + __syncthreads(); + + if (threadIdx.x < 8) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); + bias_grad[pos] = tile[0][threadIdx.x]; + } +} + +__global__ void ls_dropout_bias_bwd_kernel( + const int row_size, const float ratio, __half *__restrict__ in_grad, + __half *__restrict__ bias_grad, const __half *__restrict__ out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); + __shared__ __half2 tile[8][129]; + +#ifndef COLOSSAL_HIP + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); +#endif + + __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); + const __half2 *out_grad2 = reinterpret_cast(out_grad); + __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); + int stride = hidden_size * 128; + __half2 local_sum = __float2half2_rn(0.f); + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + for (int r = threadIdx.y; r < row_size; r += 128) { + __half2 val = out_grad2[idx]; + __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); + val *= scale * m2; + local_sum += val; + in_grad2[idx] = val; + idx += stride; + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + + __half2 sum = __float2half2_rn(0.f); + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 7; + int y = tid & (127); + if (y < 32) { +#pragma unroll + for (int i = 0; i < 4; i++) { + sum += tile[x][y + i * 32]; + } + } + __syncthreads(); + +#ifdef COLOSSAL_HIP + float2 sum_f2 = __half22float2(sum); + for (int i = 1; i < WARP_SIZE; i <<= 1) sum_f2.x += __shfl_down(sum_f2.x, i); + for (int i = 1; i < WARP_SIZE; i <<= 1) sum_f2.y += __shfl_down(sum_f2.y, i); + sum = __float22half2_rn(sum_f2); +#else + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); +#endif + + if (y == 0) tile[0][x] = sum; + __syncthreads(); + + if (threadIdx.x < 8) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); + bias_grad2[pos] = tile[0][threadIdx.x]; + } +} + +template +void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream) { + dim3 grid_dim((dim - 1) / 8 + 1); + dim3 block_dim(8, 128); + ls_dropout_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); +} + +template <> +void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad, + const __half *out_grad, const uint8_t *mask, + int row_size, int dim, float ratio, + cudaStream_t stream) { + dim >>= 1; + dim3 grid_dim((dim - 1) / 8 + 1); + dim3 block_dim(8, 128); + ls_dropout_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); +} + +template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad, + const float *out_grad, + const uint8_t *mask, int row_size, + int dim, float ratio, + cudaStream_t stream); + +/** + * @brief fused bias, activation, and dropout at the end of first ffn + * + * @thread + * gridDim.x = hidden_size / 8 + * blockDim.x = 8 + * blockDim.y = 1024 / 8 = 128 + * + * @tparam act_type activation function, like kRelu, kGelu + * @param total_count total elements + * @param ratio drop ratio + * @param out [batch_size, seq_len, hidden_size], float and __half + * @param in [batch_size, seq_len, hidden_size], float and __half + * @param mask [batch_size, seq_len, hidden_size], uint8 type + * @param bias [hidden_size], ffn bias + * @param seed seed to curand + * @param hidden_size + * @return void + */ +template +__global__ void ls_dropout_act_bias_kernel( + const int total_count, const float ratio, float *__restrict__ out, + const float *__restrict__ in, uint8_t *__restrict__ mask, + const float *__restrict__ bias, const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + const float4 *bias4 = reinterpret_cast(bias); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + int bias_i = i % (hidden_size >> 2); + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + const float4 input4 = data4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + float4 output4; + + output4.x = + activation_kernel(input4.x + b4.x) * scale * m[0]; + output4.y = + activation_kernel(input4.y + b4.y) * scale * m[1]; + output4.z = + activation_kernel(input4.z + b4.z) * scale * m[2]; + output4.w = + activation_kernel(input4.w + b4.w) * scale * m[3]; + + out4[i] = output4; +} + +template +__global__ void ls_dropout_act_bias_kernel( + const int total_count, const float ratio, __half *__restrict__ out, + const __half *__restrict__ in, uint8_t *__restrict__ mask, + const __half *__restrict__ bias, const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + const float4 *bias4 = reinterpret_cast(bias); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = (uint8_t)(rand.x > ratio); + m[5] = (uint8_t)(rand.y > ratio); + m[6] = (uint8_t)(rand.z > ratio); + m[7] = (uint8_t)(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = *m8; + + int bias_i = i % (hidden_size >> 3); + float4 val_float4 = vals_float4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + float4 out_float4; + + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + const __half2 *b_half2 = reinterpret_cast(&b4); + + __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); + __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); + __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); + __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); + out_half2[0] = __hmul2( + activation_kernel(__hadd2(val_half2[0], b_half2[0])), + scale_mask_1); + out_half2[1] = __hmul2( + activation_kernel(__hadd2(val_half2[1], b_half2[1])), + scale_mask_2); + out_half2[2] = __hmul2( + activation_kernel(__hadd2(val_half2[2], b_half2[2])), + scale_mask_3); + out_half2[3] = __hmul2( + activation_kernel(__hadd2(val_half2[3], b_half2[3])), + scale_mask_4); + outs_float4[i] = out_float4; +} + +template <> +void launch_ls_dropout_act_bias( + float *out, const float *vals, uint8_t *mask, const float *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 10; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + __half *out, const __half *vals, uint8_t *mask, const __half *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 11; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + float *out, const float *vals, uint8_t *mask, const float *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 10; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + __half *out, const __half *vals, uint8_t *mask, const __half *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 11; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +/** + * @brief fused bias, activation, and dropout backward + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @tparam act_type kRelu + * @param row_size batch_size * seq_len + * @param ratio dropout ratio + * @param in_grad [batch_size, seq_len, hidden_size], input grad + * @param bias_grad [hidden_size], bias grad + * @param out_grad [batch_size, seq_len, hidden_size], output grad + * @param mask [batch_size, seq_len, hidden_size], dropout mask + * @param hidden_size + * @return void + */ +template +__global__ void ls_dropout_act_bias_bwd_kernel( + const int row_size, const float ratio, T *in_grad, + T *__restrict__ bias_grad, const T *__restrict__ input, + const T *__restrict__ bias, const T *out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + __shared__ float tile[WARP_SIZE][WARP_SIZE + 1]; + +#ifndef COLOSSAL_HIP + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); +#endif + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + + int stride = hidden_size * WARP_SIZE; + float local_sum = 0; + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + if (col_idx < hidden_size) { + for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { + float val = out_grad[idx]; + float in = input[idx]; + float b = bias[idx % hidden_size]; + val = activation_bwd_kernel( + val * scale * static_cast(mask[idx]), in + b); + local_sum += val; + in_grad[idx] = val; + idx += stride; + } + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + float sum = tile[threadIdx.y][threadIdx.x]; + __syncthreads(); + +#ifdef COLOSSAL_HIP + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += __shfl_down(sum, i); +#else + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); +#endif + + if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; + __syncthreads(); + + if (threadIdx.y == 0) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + bias_grad[pos] = tile[0][threadIdx.x]; + } +} + +// @brief fused bias, activation, and dropout backward +// It is deprecated for precision reason. Keep it for future optimization. +// +// template +// __global__ void ls_dropout_act_bias_bwd_kernel( +// const int row_size, const float ratio, __half * in_grad, +// __half *__restrict__ bias_grad, const __half *__restrict__ input, const +// __half *__restrict__ bias, const __half * out_grad, const uint8_t +// *__restrict__ mask, const int hidden_size) { +// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); +// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; + +// cg::thread_block b = cg::this_thread_block(); +// cg::thread_block_tile g = cg::tiled_partition(b); + +// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); +// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); +// const __half2 *out_grad2 = reinterpret_cast(out_grad); +// const __half2 *input2 = reinterpret_cast(input); +// const __half2 *bias2 = reinterpret_cast(bias); + +// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + +// int stride = hidden_size * WARP_SIZE; +// __half2 local_sum = __float2half2_rn(0.f); + +// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); +// if (col_idx < hidden_size) { +// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { +// __half2 val = out_grad2[idx]; +// __half2 in2 = input2[idx]; +// __half2 b2 = bias2[idx % hidden_size ]; +// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); +// val = activation_bwd_kernel(val * scale +// * +// m2, +// in2+b2); +// local_sum += val; +// in_grad2[idx] = val; +// idx += stride; +// } +// } + +// tile[threadIdx.x][threadIdx.y] = local_sum; +// __syncthreads(); +// __half2 sum = tile[threadIdx.y][threadIdx.x]; +// __syncthreads(); + +// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + +// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; +// __syncthreads(); + +// if (threadIdx.y == 0) { +// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); +// bias_grad2[pos] = tile[0][threadIdx.x]; +// } +// } + +template +void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, + const T *bias, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream) { + dim3 grid_dim((dim - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + ls_dropout_act_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim); +} + +// template <> +// void launch_ls_dropout_act_bias_bwd( +// __half *in_grad, __half *bias_grad,const __half *input, const __half +// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int +// dim, float ratio, cudaStream_t stream) { +// dim >>= 1; +// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); +// dim3 block_dim(WARP_SIZE, WARP_SIZE); +// ls_dropout_act_bias_bwd_kernel +// <<>>(row_size, ratio, in_grad, +// bias_grad, +// input, bias,out_grad, mask, dim); +// } + +template void launch_ls_dropout_act_bias_bwd( + float *in_grad, float *bias_grad, const float *input, const float *bias, + const float *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, + const __half *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + float *in_grad, float *bias_grad, const float *input, const float *bias, + const float *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, + const __half *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..46e5f389cc6d5eb38ec23792d4307c516916b3b2 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu @@ -0,0 +1,243 @@ +#include "kernels.h" + +#ifndef COLOSSAL_HIP +#include + +namespace cg = cooperative_groups; +#endif + +#include "kernels.h" + + +/** +@brief: fuse_transpose_bias +Calculate the sum of elements in each column of the matrix. + +@thread +gridDim.x = ceil(cols / WARP_SIZE) +blockDim.x = WARP_SIZE +blockDim.y = WARP_SIZE + +@param +inp: [rows, cols] +out: [cols] +rows: the number of rows in the matrix +cols: the number of cols in the matrix +*/ +template +__global__ void column_sum_reduce(const T *__restrict__ inp, + T *__restrict__ out, int rows, int cols) { + __shared__ float tile[WARP_SIZE][WARP_SIZE]; + +#ifndef COLOSSAL_HIP + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); +#endif + + int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + int y_stride = cols * WARP_SIZE; + float localSum = 0; + + // Loop across matrix row + // TODO: optimize to log complexity + if (idx < cols) { + int offset = flat_2dim(threadIdx.y, idx, cols); + for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { + localSum += (float)inp[offset]; + offset += y_stride; + } + } + + // The sum of a row in tile is equal to the sum of a col in original matrix + tile[threadIdx.x][threadIdx.y] = localSum; + + __syncthreads(); + + // Sum the shared buffer. + // The change of threadIdx.x is continuous + float sum = tile[threadIdx.y][threadIdx.x]; + + __syncthreads(); + + // Calculate the sum of a row in tile +#ifdef COLOSSAL_HIP + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += __shfl_down(sum, i); +#else + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); +#endif + + if (threadIdx.x == 0) { + int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); + if (pos < cols) out[pos] = sum; + } +} + +// [r, c] -> [c] +template <> +void launch_fuse_transpose_bias_kernel(const float *inp, float *out, + int rows, int cols, + cudaStream_t stream) { + dim3 grid_dim((cols - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + + column_sum_reduce + <<>>(inp, out, rows, cols); +} + +template <> +void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, + int rows, int cols, + cudaStream_t stream) { + dim3 grid_dim((cols - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + + column_sum_reduce<__half> + <<>>(inp, out, rows, cols); +} + +/** +@brief: fused_add2 +Add two matrix inp1 and inp2 to out. + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = min(hidden_dim, MAX_THREADS) + +@param +inp1: [batch_size, seq_len, hidden_dim] +inp2: [batch_size, seq_len, hidden_dim] +out: [batch_size, seq_len, hidden_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +*/ +template +__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2, + int hidden_dim); + +template <> +__global__ void fused_add2_kernel(float *out, const float *inp1, + const float *inp2, int hidden_dim) { + int row_id = blockIdx.x; + int offset = flat_2dim(row_id, 0, hidden_dim); + + const float4 *inp1_4 = reinterpret_cast(inp1); + const float4 *inp2_4 = reinterpret_cast(inp2); + float4 *out_4 = reinterpret_cast(out); + float4 vinp1; + float4 vinp2; + float4 val; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinp1 = inp1_4[offset + i]; + vinp2 = inp2_4[offset + i]; + val.x = vinp1.x + vinp2.x; + val.y = vinp1.y + vinp2.y; + val.z = vinp1.z + vinp2.z; + val.w = vinp1.w + vinp2.w; + out_4[offset + i] = val; + } +} + +template <> +__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1, + const __half *inp2, int hidden_dim) { + int row_id = blockIdx.x; + int offset = flat_2dim(row_id, 0, hidden_dim); + + const float4 *inp1_4 = reinterpret_cast(inp1); + const float4 *inp2_4 = reinterpret_cast(inp2); + float4 *out_4 = reinterpret_cast(out); + float4 vinp1; + float4 vinp2; + float4 val; + __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1); + __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2); + __half2 *h2_val = reinterpret_cast<__half2 *>(&val); + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinp1 = inp1_4[offset + i]; + vinp2 = inp2_4[offset + i]; + h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]); + h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]); + h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]); + h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]); + out_4[offset + i] = val; + } +} + +//[b, s, h] -> [b, s, h] +template <> +void launch_fused_add2(float *out, const float *inp1, const float *inp2, + int batch_size, int seq_len, int hidden_dim, + cudaStream_t &stream) { + hidden_dim >>= 2; + + dim3 grid_dim(batch_size * seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + fused_add2_kernel<<>>(out, inp1, inp2, + hidden_dim); +} + +template <> +void launch_fused_add2<__half>(__half *out, const __half *inp1, + const __half *inp2, int batch_size, int seq_len, + int hidden_dim, cudaStream_t &stream) { + hidden_dim >>= 3; + + dim3 grid_dim(batch_size * seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + fused_add2_kernel<<>>(out, inp1, inp2, + hidden_dim); +} + +template +__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output, + int sz0, int sz2, int sz1_1, int sz1_2) { + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x); + if (idx >= nele) { + return; + } + float4 *dst_ptr = (float4 *)output + idx; + int idx2 = idx % sz2; + idx = idx / sz2; + int idx1 = idx % (sz1_1 + sz1_2); + int idx0 = idx / (sz1_1 + sz1_2); + float4 *src_ptr = nullptr; + int sz1 = 0; + if (idx1 < sz1_1) { + sz1 = sz1_1; + src_ptr = (float4 *)inp1; + } else { + idx1 -= sz1_1; + sz1 = sz1_2; + src_ptr = (float4 *)inp2; + } + src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2); + dst_ptr[0] = src_ptr[0]; +} + +template <> +void launch_concat3_dim1(const float *inp1, const float *inp2, + float *output, int sz0, int sz2, int sz1_1, + int sz1_2, cudaStream_t stream) { + sz2 >>= 2; + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_concat3_dim1<<>>( + inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); +} + +template <> +void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, + __half *output, int sz0, int sz2, int sz1_1, + int sz1_2, cudaStream_t stream) { + sz2 >>= 3; + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_concat3_dim1<<>>( + inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h b/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..6add8a0ec242c47b87f146c490ee3f784acb950d --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h @@ -0,0 +1,387 @@ +/* Copyright 2021 The LightSeq Team + Copyright Tencent/TurboTransformers + This block_reduce_n is adapted from Tencent/TurboTransformers +*/ +#pragma once +#include +#include +#include + +enum class ReduceType { kMax = 0, kSum }; +const unsigned int WARP_REDUCE_MASK = 0xffffffff; +const float REDUCE_FLOAT_INF_NEG = -100000000.f; +const float REDUCE_FLOAT_INF_POS = 100000000.f; +#ifdef COLOSSAL_HIP +const unsigned int WARP_REDUCE_SIZE = 64; +#else +const unsigned int WARP_REDUCE_SIZE = 32; +#endif + +template +__forceinline__ __device__ T warpReduceSum(T val) { + for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1) +#ifdef COLOSSAL_HIP + val += __shfl_xor_sync(val, mask, WARP_REDUCE_SIZE); +#else + val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE); +#endif + return val; +} + +/* Calculate the sum of all elements in a block */ +template +__forceinline__ __device__ T blockReduceSum(T val) { + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if (lane == 0) shared[wid] = val; + __syncthreads(); + + val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f; + val = warpReduceSum(val); + return val; +} + +template +__inline__ __device__ void blockReduce(float *pval); + +// use template to make code more concise +template +__inline__ __device__ void warpReduce(float *pval); + +// static +template <> +__inline__ __device__ void warpReduce(float *pval) { +#ifdef COLOSSAL_HIP + *pval = max(*pval, __shfl_xor(*pval, 32, WARP_REDUCE_SIZE)); + *pval = max(*pval, __shfl_xor(*pval, 16, WARP_REDUCE_SIZE)); + *pval = max(*pval, __shfl_xor(*pval, 8, WARP_REDUCE_SIZE)); + *pval = max(*pval, __shfl_xor(*pval, 4, WARP_REDUCE_SIZE)); + *pval = max(*pval, __shfl_xor(*pval, 2, WARP_REDUCE_SIZE)); + *pval = max(*pval, __shfl_xor(*pval, 1, WARP_REDUCE_SIZE)); +#else + *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32)); + *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32)); + *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32)); + *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32)); + *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32)); +#endif +} + +template <> +__inline__ __device__ void warpReduce(float *pval) { + float val0_tmp, val1_tmp; +#ifdef COLOSSAL_HIP +#define WarpReduceMaxOneStep(a, b) \ + val0_tmp = __shfl_xor(*(pval), a, b); \ + val1_tmp = __shfl_xor(*(pval + 1), a, b); \ + *(pval) = max(val0_tmp, *(pval)); \ + *(pval + 1) = max(val1_tmp, *(pval + 1)); + + WarpReduceMaxOneStep(32, WARP_REDUCE_SIZE); + WarpReduceMaxOneStep(16, WARP_REDUCE_SIZE); + WarpReduceMaxOneStep(8, WARP_REDUCE_SIZE); + WarpReduceMaxOneStep(4, WARP_REDUCE_SIZE); + WarpReduceMaxOneStep(2, WARP_REDUCE_SIZE); + WarpReduceMaxOneStep(1, WARP_REDUCE_SIZE); +#else +#define WarpReduceMaxOneStep(a, b) \ + val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \ + val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ + *(pval) = max(val0_tmp, *(pval)); \ + *(pval + 1) = max(val1_tmp, *(pval + 1)); + + WarpReduceMaxOneStep(16, 32); + WarpReduceMaxOneStep(8, 32); + WarpReduceMaxOneStep(4, 32); + WarpReduceMaxOneStep(2, 32); + WarpReduceMaxOneStep(1, 32); +#endif +#undef WarpReduceMaxOneStep +} + +template <> +__inline__ __device__ void warpReduce(float *pval) { +#ifdef COLOSSAL_HIP + *pval += __shfl_xor(*pval, 32, WARP_REDUCE_SIZE); + *pval += __shfl_xor(*pval, 16, WARP_REDUCE_SIZE); + *pval += __shfl_xor(*pval, 8, WARP_REDUCE_SIZE); + *pval += __shfl_xor(*pval, 4, WARP_REDUCE_SIZE); + *pval += __shfl_xor(*pval, 2, WARP_REDUCE_SIZE); + *pval += __shfl_xor(*pval, 1, WARP_REDUCE_SIZE); +#else + *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32); + *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32); + *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32); + *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32); + *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32); +#endif +} + +/* + * Unorll for loop for warpreduce to + * imporve instruction issue efficiency + * ElemX means there are X numbers to be summed + */ + +template <> +__inline__ __device__ void warpReduce(float *pval) { + float val0_tmp, val1_tmp; +#ifdef COLOSSAL_HIP +#define WarpReduceSumOneStep(a, b) \ + val0_tmp = __shfl_xor(*(pval + 0), a, b); \ + val1_tmp = __shfl_xor(*(pval + 1), a, b); \ + *(pval + 0) += val0_tmp; \ + *(pval + 1) += val1_tmp + + WarpReduceSumOneStep(32, WARP_REDUCE_SIZE); + WarpReduceSumOneStep(16, WARP_REDUCE_SIZE); + WarpReduceSumOneStep(8, WARP_REDUCE_SIZE); + WarpReduceSumOneStep(4, WARP_REDUCE_SIZE); + WarpReduceSumOneStep(2, WARP_REDUCE_SIZE); + WarpReduceSumOneStep(1, WARP_REDUCE_SIZE); +#else +#define WarpReduceSumOneStep(a, b) \ + val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \ + val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ + *(pval + 0) += val0_tmp; \ + *(pval + 1) += val1_tmp + + WarpReduceSumOneStep(16, 32); + WarpReduceSumOneStep(8, 32); + WarpReduceSumOneStep(4, 32); + WarpReduceSumOneStep(2, 32); + WarpReduceSumOneStep(1, 32); +#endif + +#undef WarpReduceSumOneStep +} + +template <> +__inline__ __device__ void warpReduce(float *pval) { + float val0_tmp, val1_tmp, val2_tmp, val3_tmp; +#ifdef COLOSSAL_HIP +#define WarpReduceSumOneStep(a, b) \ + val0_tmp = __shfl_xor(*(pval + 0), a, b); \ + val1_tmp = __shfl_xor(*(pval + 1), a, b); \ + val2_tmp = __shfl_xor(*(pval + 2), a, b); \ + val3_tmp = __shfl_xor(*(pval + 3), a, b); \ + *(pval + 0) += val0_tmp; \ + *(pval + 1) += val1_tmp; \ + *(pval + 2) += val2_tmp; \ + *(pval + 3) += val3_tmp + + WarpReduceSumOneStep(32, WARP_REDUCE_SIZE); + WarpReduceSumOneStep(16, WARP_REDUCE_SIZE); + WarpReduceSumOneStep(8, WARP_REDUCE_SIZE); + WarpReduceSumOneStep(4, WARP_REDUCE_SIZE); + WarpReduceSumOneStep(2, WARP_REDUCE_SIZE); + WarpReduceSumOneStep(1, WARP_REDUCE_SIZE); +#else +#define WarpReduceSumOneStep(a, b) \ + val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \ + val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ + val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \ + val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \ + *(pval + 0) += val0_tmp; \ + *(pval + 1) += val1_tmp; \ + *(pval + 2) += val2_tmp; \ + *(pval + 3) += val3_tmp + + WarpReduceSumOneStep(16, 32); + WarpReduceSumOneStep(8, 32); + WarpReduceSumOneStep(4, 32); + WarpReduceSumOneStep(2, 32); + WarpReduceSumOneStep(1, 32); +#endif +#undef WarpReduceSumOneStep +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 1; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = 0.f; + } + } + warpReduce(pval); +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 2; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = 0.f; + } + } + warpReduce(pval); +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 4; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = 0.f; + } + } + warpReduce(pval); +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 1; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = REDUCE_FLOAT_INF_NEG; + } + } + warpReduce(pval); +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 1; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = REDUCE_FLOAT_INF_NEG; + } + } + warpReduce(pval); +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 1; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = REDUCE_FLOAT_INF_NEG; + } + } + warpReduce(pval); +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/context.h b/colossalai/kernel/cuda_native/csrc/kernels/include/context.h new file mode 100644 index 0000000000000000000000000000000000000000..f7d75f38cc2b568db74c935ef26cc14afce312ef --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/context.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +#include +#include + +#include "cuda_util.h" + +class Context { + public: + Context() : _stream(nullptr) { + CHECK_GPU_ERROR(cublasCreate(&_cublasHandle)); + } + + virtual ~Context() {} + + static Context &Instance() { + static Context _ctx; + return _ctx; + } + + void set_stream(cudaStream_t stream) { + _stream = stream; + CHECK_GPU_ERROR(cublasSetStream(_cublasHandle, _stream)); + } + + cudaStream_t get_stream() { return _stream; } + + cublasHandle_t get_cublashandle() { return _cublasHandle; } + + private: + cudaStream_t _stream; + cublasHandle_t _cublasHandle; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h new file mode 100644 index 0000000000000000000000000000000000000000..f4e9befc6588563e04c889da6460bd50ffa5aa56 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include + +#include + +#include "cuda_util.h" + +template +class CrossEntropyLayer { + public: + CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens); + + virtual ~CrossEntropyLayer(); + + void Forward(const T *inputs_ptr, const int *targets_ptr, float *outputs_ptr, + float *nll_loss_ptr); + + void Backward(const float *grad_outputs_ptr, const T *inputs_ptr, + const int *targets_ptr, T *grad_inputs_ptr); + + void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size); + + private: + void allocate_mem_buffer() { + // allocate local gpu memory + _loss_buffer = cuda_malloc(_max_batch_tokens * 2); + } + + void free_mem_buffer() { + // free local gpu memory + cuda_free(_loss_buffer); + } + + const int _padding_idx; + const float _epsilon; + const int _max_batch_tokens; + + size_t _batch_size; + size_t _seq_len; + size_t _vocab_size; + + float *_loss_buffer; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h new file mode 100644 index 0000000000000000000000000000000000000000..af4cfaa4ecbb4902211b2781b0a7b5f2206d12e2 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h @@ -0,0 +1,71 @@ +/* Copyright 2021 The LightSeq Team + Copyright Microsoft DeepSpeed + This file is adapted from Microsoft DeepSpeed +*/ +#pragma once + +#include +#include +#include +#include +#include +#ifndef COLOSSAL_HIP +#include +#endif +#include + +#ifdef COLOSSAL_HIP +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const float *A, + const float *B, float *C, + rocblas_gemm_algo algo = rocblas_gemm_algo_standard); + +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const __half *A, + const __half *B, __half *C, + rocblas_gemm_algo algo = rocblas_gemm_algo_standard); + +int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, + const float *alpha, const float *beta, + const float *A, const float *B, float *C, + cublasOperation_t op_A, cublasOperation_t op_B, + int stride_A, int stride_B, int stride_C, + int batch, + rocblas_gemm_algo algo = rocblas_gemm_algo_standard); + +int cublas_strided_batched_gemm( + cublasHandle_t handle, int m, int n, int k, const float *alpha, + const float *beta, const __half *A, const __half *B, __half *C, + cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, int stride_B, + int stride_C, int batch, + rocblas_gemm_algo algo = rocblas_gemm_algo_standard); +#else +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const float *A, + const float *B, float *C, + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); + +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const __half *A, + const __half *B, __half *C, + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); + +int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, + const float *alpha, const float *beta, + const float *A, const float *B, float *C, + cublasOperation_t op_A, cublasOperation_t op_B, + int stride_A, int stride_B, int stride_C, + int batch, + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); + +int cublas_strided_batched_gemm( + cublasHandle_t handle, int m, int n, int k, const float *alpha, + const float *beta, const __half *A, const __half *B, __half *C, + cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, int stride_B, + int stride_C, int batch, + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h new file mode 100644 index 0000000000000000000000000000000000000000..202c181bc9a3c7a8678e694e44131bf813074ee5 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include +#ifndef COLOSSAL_HIP +#include +#endif + +#include +#include +#include +#include +#include +#include + +template +void check_gpu_error(T result, char const *const func, const char *const file, + int const line); + +#define CHECK_GPU_ERROR(val) check_gpu_error((val), #val, __FILE__, __LINE__) + +template +void print_vec(const T *outv, std::string outn, int num_output_ele); + +template +T *cuda_malloc(size_t ele_num); + +void cuda_free(void *pdata); + +template +void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf, + std::string file, int line, cudaStream_t stream); + +#define CHECK_NAN_INF(ptr, size, stream) \ + check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \ + check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream)) diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h new file mode 100644 index 0000000000000000000000000000000000000000..025fbf3f8f15cef0ecb72e827bdade248a58d72f --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h @@ -0,0 +1,96 @@ +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +template +class Dropout { + public: + struct Config { + float ratio; + bool training; + + Config(float r) : ratio(r), training(true) {} + float RATIO() const { return training ? ratio : 0.0; } + }; + + Dropout(const Config &config, size_t max_ele_num) + : _config(config), _mask(nullptr) { + _mask = cuda_malloc(max_ele_num); + } + + virtual ~Dropout() { cuda_free(_mask); } + + // after attention softmax + void dropout(T *output, const T *input, int count, cudaStream_t stream, + bool bwd = false) { + launch_ls_dropout(output, input, _mask, count, _config.RATIO(), stream, + bwd); + } + + void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { + launch_ls_dropout(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), + stream, true); + } + + // transformer layer's postprocessing dropout, after attn or ffn module, + // before residual add. + void bias_dropout_residual(T *output, const T *input, const T *residual, + const T *bias, int rows, int cols, + cudaStream_t stream) { + launch_ls_dropout_res_bias(output, input, _mask, bias, residual, + rows * cols, cols, _config.RATIO(), stream); + } + + void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, + int rows, int cols, cudaStream_t stream) { + launch_ls_dropout_bias_bwd(d_input, d_bias, d_output, _mask, rows, cols, + _config.RATIO(), stream); + } + + // dropout inside ffn. + void bias_act_dropout(T *output, const T *input, const T *bias, int rows, + int cols, std::string activation_fn, + cudaStream_t stream) { + if (activation_fn == "relu") { + launch_ls_dropout_act_bias( + output, input, _mask, bias, rows * cols, cols, _config.RATIO(), + stream); + } else if (activation_fn == "gelu") { + launch_ls_dropout_act_bias( + output, input, _mask, bias, rows * cols, cols, _config.RATIO(), + stream); + } else { + throw std::runtime_error("not supported activation: " + activation_fn); + } + } + + void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, + const T *bias, int rows, int cols, + std::string activation_fn, cudaStream_t stream) { + if (activation_fn == "relu") { + launch_ls_dropout_act_bias_bwd( + d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, + _config.RATIO(), stream); + } else if (activation_fn == "gelu") { + launch_ls_dropout_act_bias_bwd( + d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, + _config.RATIO(), stream); + } else { + throw std::runtime_error("not supported activation: " + activation_fn); + } + } + + bool HasDropout() const { return _config.RATIO() > 0.0; } + + void SetTrainingMode(bool training) { _config.training = training; } + + private: + uint8_t *_mask; + Config _config; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h new file mode 100644 index 0000000000000000000000000000000000000000..cfb0f1800664f5ea55af3a1399c1c5a6acdac9d9 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h @@ -0,0 +1,84 @@ +#pragma once + +/* Copyright 2021 The LightSeq Team + Copyright Microsoft DeepSpeed + This file is adapted from Microsoft DeepSpeed +*/ +#include +#include +#include + +#include + +#include "cublas_wrappers.h" +#include "kernels.h" + +template +class FeedForward { + public: + struct Config { + int outputSize; + int inputSize; + std::array gemm_algos; + Config(int outputs, int inputs) + : outputSize(outputs), + inputSize(inputs), + gemm_algos(std::array({99, 99, 99})) {} + }; + + FeedForward(Config config) : config_(config) {} + + ~FeedForward() {} + + void Forward(int bsz, const T *input_ptr, const T *weights, T *out, + cublasHandle_t &_cublasHandle) { + float alpha = T(1.); + float beta = T(0.); + +#ifdef COLOSSAL_HIP + cublas_gemm_ex(_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, config_.outputSize, + bsz, config_.inputSize, &alpha, &beta, weights, input_ptr, + out, rocblas_gemm_algo(rocblas_gemm_algo_standard)); +#else + cublas_gemm_ex(_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, config_.outputSize, + bsz, config_.inputSize, &alpha, &beta, weights, input_ptr, + out, cublasGemmAlgo_t(config_.gemm_algos[0])); +#endif + } + void Backward(int bsz, const T *out_grad, const T *input_ptr, + const T *weights, T *weights_grad, T *bias_grad, + cublasHandle_t &_cublasHandle, cudaStream_t &stream, + T *inp_grad_out = nullptr, T *out_grad_trans_out = nullptr, + bool compute_bias = true) { + float alpha = (T)1.0, beta = (T)0.0; +#ifdef COLOSSAL_HIP + cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T, config_.inputSize, + config_.outputSize, bsz, &alpha, &beta, input_ptr, out_grad, + weights_grad, rocblas_gemm_algo(rocblas_gemm_algo_standard)); + + cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, config_.inputSize, + bsz, config_.outputSize, &alpha, &beta, weights, out_grad, + inp_grad_out, rocblas_gemm_algo(rocblas_gemm_algo_standard)); +#else + cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T, config_.inputSize, + config_.outputSize, bsz, &alpha, &beta, input_ptr, out_grad, + weights_grad, cublasGemmAlgo_t(config_.gemm_algos[1])); + + cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, config_.inputSize, + bsz, config_.outputSize, &alpha, &beta, weights, out_grad, + inp_grad_out, cublasGemmAlgo_t(config_.gemm_algos[2])); +#endif + if (compute_bias) { + launch_fuse_transpose_bias_kernel(out_grad, bias_grad, bsz, + config_.outputSize, stream); + } + } + + void reset_size(int outputSize, int inputSize) { + config_.outputSize = outputSize; + config_.inputSize = inputSize; + } + + private: + Config config_; +}; \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..1608dfd8b13d46b671627a1c6751a7a1bc52ca66 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h @@ -0,0 +1,278 @@ +#pragma once + +#include +#include +#ifdef COLOSSAL_HIP +#include +#else +#include +#endif +#include +#include +#include + +#define MAX_THREADS 1024 +#ifdef COLOSSAL_HIP + #define WARP_SIZE 64 +#else + #define WARP_SIZE 32 +#endif + +enum class ActivationType { kRelu, kGelu }; + +void launch_curand_init(int total_count, int dim, cudaStream_t stream); + +template +void launch_layer_norm(T *ln_res, T *vars, T *means, const T *inp, + const T *scale, const T *bias, int batch_size, + int hidden_dim, cudaStream_t stream); + +template +void launch_ln_bw(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, const T *gamma, + const T *betta, const T *vars, const T *means, int batch, + int hidden_dim, cudaStream_t stream[2]); + +template +void launch_attn_softmax(T *vals, const T *attn_mask, int batch_size, int heads, + int from_len, int to_len, bool mask_future, + cudaStream_t stream); + +template +void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, + int softmax_len, cudaStream_t stream); + +// [b, s, h] -> [b, nh, s, ad] +template +void launch_transform_0213(T *output, const T *vals, int batch_size, + int seq_length, int hidden_dim, int nhead, + cudaStream_t stream); + +// [b, s, 3, h] -> [3, b, nh, s, ad] +template +void launch_bias_add_transform_20314(T *output, const T *input, const T *bias, + int dim_0, int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream); + +// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] +template +void launch_transform4d_0213(T *output, const T *vals, int batch_size, + int seq_len, int hidden_dim, int nhead, + int trans_count, cudaStream_t stream); + +template +void launch_ls_dropout(T *out, const T *vals, uint8_t *mask, int total_count, + float ratio, cudaStream_t stream, bool backward = false); + +template +void launch_ls_dropout_res_bias(T *out, const T *vals, uint8_t *mask, + const T *bias, const T *residual, + int total_count, int dim, float ratio, + cudaStream_t stream); + +template +void launch_ls_dropout_act_bias(T *out, const T *vals, uint8_t *mask, + const T *bias, int total_count, int dim, + float ratio, cudaStream_t stream); + +template +void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template +void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, + const T *bias, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template +void launch_fuse_transpose_bias_kernel(const T *inp, T *out, int rows, int cols, + cudaStream_t stream); + +void launch_param_update(const float *input, __half *output, int size, + cudaStream_t stream); + +template +void launch_concat3_dim1(const T *inp1, const T *inp2, T *output, int sz0, + int sz2, int sz1_1, int sz1_2, cudaStream_t stream); + +template +void launch_fused_add2(T *out, const T *inp1, const T *inp2, int batch_size, + int seq_len, int hidden_size, cudaStream_t &stream); + +template +void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr, + float *outputs_ptr, float *nll_loss_ptr, + float *loss_buffer, const int padding_idx, + const float epsilon, const int batch_size, + const int seq_len, const int vocab_size, + cudaStream_t stream); + +template +void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr, + const int *targets_ptr, T *grad_inputs_ptr, + const int padding_idx, const float epsilon, + const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream); + +template +void launch_lookup_scale_pos_dropout( + T *output, const int *input, const T *embeddings, const T *pos_embeddings, + uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim, + int padding_idx, float dropout_ratio, int step, cudaStream_t &stream); + +template +void launch_d_lookup_scale_pos_dropout( + T *grad_embeddings, const T *grad_output, const int *input, + const uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim, + int vocab_size, int padding_idx, float dropout_ratio, cudaStream_t &stream); + +/* Convert 2-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int flat_2dim(int id1, int id2, int dim2) { + return id1 * dim2 + id2; +} + +/* Convert 3-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3, + int dim2, int dim3) { + return id1 * dim2 * dim3 + id2 * dim3 + id3; +} + +/* Convert 4-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int +flat_4dim(int id1, int id2, int id3, int id4, int dim2, int dim3, int dim4) { + // return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4; + int res = id4; + + int ld = dim4; + res += id3 * ld; + + ld *= dim3; + res += id2 * ld; + + ld *= dim2; + res += id1 * ld; + + return res; +} + +/* Convert 5-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int flat_5dim(int id1, int id2, int id3, + int id4, int id5, int dim2, + int dim3, int dim4, + int dim5) { + // return id1*(dim2*dim3*dim4*dim5) + id2*(dim3*dim4*dim5) + id3*(dim4*dim5) + + // id4*dim5 + dim5; + int res = id5; + + int ld = dim5; + res += id4 * ld; + + ld *= dim4; + res += id3 * ld; + + ld *= dim3; + res += id2 * ld; + + ld *= dim2; + res += id1 * ld; + + return res; +} + +/* Convert 6-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3, + int id4, int id5, int id6, + int dim2, int dim3, int dim4, + int dim5, int dim6) { + // return id1*(dim2*dim3*dim4*dim5*dim6) + id2*(dim3*dim4*dim5*dim6) + + // id3*(dim4*dim5*dim6) + id4*(dim5*dim6) + id5*dim6 + id6; + int res = id6; + + int ld = dim6; + res += id5 * ld; + + ld *= dim5; + res += id4 * ld; + + ld *= dim4; + res += id3 * ld; + + ld *= dim3; + res += id2 * ld; + + ld *= dim2; + res += id1 * ld; + + return res; +} + +/* Convert vector index to 6-dim tensor index */ +__forceinline__ __host__ __device__ void +decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5, + int *id0, int *id1, int *id2, int *id3, int *id4, int *id5) { + *id5 = src % dim5; + src /= dim5; + + *id4 = src % dim4; + src /= dim4; + + *id3 = src % dim3; + src /= dim3; + + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +/* Convert vector index to 5-dim tensor index */ +__forceinline__ __host__ __device__ void +decompose_5dim(int src, int dim1, int dim2, int dim3, int dim4, int *id0, + int *id1, int *id2, int *id3, int *id4) { + *id4 = src % dim4; + src /= dim4; + + *id3 = src % dim3; + src /= dim3; + + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +/* Convert vector index to 4-dim tensor index */ +__forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1, + int dim2, int dim3, + int *id0, int *id1, + int *id2, int *id3) { + *id3 = src % dim3; + src /= dim3; + + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +/* Convert vector index to 3-dim tensor index */ +__forceinline__ __host__ __device__ void +decompose_3dim(int src, int dim1, int dim2, int *id0, int *id1, int *id2) { + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +/* Convert vector index to 2-dim tensor index */ +__forceinline__ __host__ __device__ void decompose_2dim(int src, int dim1, + int *id0, int *id1) { + *id1 = src % dim1; + *id0 = src / dim1; +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh b/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh new file mode 100644 index 0000000000000000000000000000000000000000..4f65e7b54ba19e9520e19d969bebbe4e5d43c266 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh @@ -0,0 +1,12 @@ +// copied from https://github.com/dmlc/dgl/pull/2758 +#ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_ +#define DGL_ARRAY_CUDA_DGL_CUB_CUH_ + +#define CUB_NS_PREFIX namespace ls { +#define CUB_NS_POSTFIX } +#include "cub/cub.cuh" +#include "cub/util_allocator.cuh" +#undef CUB_NS_POSTFIX +#undef CUB_NS_PREFIX + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h new file mode 100644 index 0000000000000000000000000000000000000000..d88fae2620830563a7788bc3168120495f345b67 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +using namespace std; + +template class Normalize_Layer { +public: + struct Config { + uint32_t hidden_dim; + bool use_mean; + Config(uint32_t hidden_dim, bool use_mean = false) + : hidden_dim(hidden_dim), use_mean(use_mean) {} + }; + + Normalize_Layer(Config config, size_t max_rows) + : config_(config), vars_(nullptr), means_(nullptr) { + vars_ = cuda_malloc(max_rows); + if (config_.use_mean) { + means_ = cuda_malloc(max_rows); + } + } + + ~Normalize_Layer() { + cuda_free(vars_); + cuda_free(means_); + } + + void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta, + int batch_size, cudaStream_t stream) { + launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, + config_.hidden_dim, stream); + } + + /* + residual_grad, inp_or_out, betta should be treated carefully. + inp_or_out = input if use_mean else output + residual_grad, betta can be nullptr. + residual_grad will be added to dinp if it is not nullptr + which is useful in transformer layer when pre-ln + betta are only used to compute xhat, + (use_mean == false) ^ (betta == nullptr) should be true + */ + void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, const T *gamma, + const T *betta, int batch_size, cudaStream_t stream[2]) { + launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, + inp_or_out, gamma, betta, vars_, means_, batch_size, + config_.hidden_dim, stream); + } + + inline bool use_mean() const { return config_.use_mean; } + +private: + Config config_; + T *vars_; + T *means_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..b917abaf0336a8399ce5900da03c94fb80eb54b5 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +using namespace std; + +template +class Softmax { + public: + struct Config { + size_t nhead; + Config(size_t nhead) : nhead(nhead) {} + }; + + Softmax(Config config) : config_(config) {} + + ~Softmax() {} + + void Forward(T *vals, const T *attn_mask, int batch_size, int from_len, + int to_len, cudaStream_t &stream, bool mask_future = true) { + launch_attn_softmax(vals, attn_mask, batch_size, config_.nhead, from_len, + to_len, mask_future, stream); + } + + void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len, + int to_len, cudaStream_t stream) { + launch_attn_softmax_bw(out_grad, soft_out, + batch_size * config_.nhead * from_len, to_len, + stream); + } + + void reset_size(size_t nhead) { config_.nhead = nhead; } + + private: + Config config_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h b/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..66729e88f7d4a3c3efa15d90a3cad8f4c03bddf8 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h @@ -0,0 +1,122 @@ +/* Copyright 2021 The LightSeq Team + Copyright Microsoft DeepSpeed + This file is adapted from Microsoft DeepSpeed +*/ +#pragma once + +#include +#include +#include + +#include + +#include "cublas_wrappers.h" + +template +class StridedBatchGemm { + public: + struct Config { + int m; + int n; + int k; + float alpha; + float beta; + cublasOperation_t op_A; + cublasOperation_t op_B; + std::array gemm_algos; + + Config(float param_alpha, float param_beta, cublasOperation_t opA, + cublasOperation_t opB) + : alpha(param_alpha), + beta(param_beta), + op_A(opA), + op_B(opB), + gemm_algos(std::array({99, 99, 99})) {} + void SetConfig(int mm, int nn, int kk) { + m = mm; + n = nn; + k = kk; + } + }; + + StridedBatchGemm(const Config &config) : _config(config) {} + + virtual ~StridedBatchGemm() {} + + void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b, + cublasHandle_t handle) { + int stride_a = _config.m * _config.k; + int stride_b = _config.n * _config.k; + int stride_c = _config.m * _config.n; + +#ifdef COLOSSAL_HIP + cublas_strided_batched_gemm( + handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta, + _buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a, + stride_b, stride_c, bsz, rocblas_gemm_algo(rocblas_gemm_algo_standard)); +#else + cublas_strided_batched_gemm( + handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta, + _buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a, + stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0])); +#endif + } + + void Backward(int bsz, const T *d_output, const T *_buffer_a, + const T *_buffer_b, cublasHandle_t handle, + T *inpGradA = nullptr, T *inpGradB = nullptr) { + int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m); + int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k); + + int stride_a = mb * _config.n; + int stride_b = _config.n * kb; + int stride_c = _config.m * _config.k; + + // B need to transpose. + cublasOperation_t op_b = + (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); + + // Calculate d_A. +#ifdef COLOSSAL_HIP + cublas_strided_batched_gemm( + handle, mb, kb, _config.n, &_config.alpha, &_config.beta, + (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output), + (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA, + CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz, + rocblas_gemm_algo(rocblas_gemm_algo_standard)); +#else + cublas_strided_batched_gemm( + handle, mb, kb, _config.n, &_config.alpha, &_config.beta, + (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output), + (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA, + CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz, + cublasGemmAlgo_t(_config.gemm_algos[1])); +#endif + + // A need to transpose. + cublasOperation_t op_a = + (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); + + stride_a = _config.m * _config.k; + stride_b = _config.m * _config.n; + stride_c = _config.n * _config.k; + + // Calculate d_B. +#ifdef COLOSSAL_HIP + cublas_strided_batched_gemm( + handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta, + _buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b, + stride_c, bsz, rocblas_gemm_algo(rocblas_gemm_algo_standard)); +#else + cublas_strided_batched_gemm( + handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta, + _buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b, + stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2])); +#endif + } + + inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); } + + private: + Config _config; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..1e9b3846fdbbbef0f3698bdb5457afb37d67bc38 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu @@ -0,0 +1,1295 @@ +#include "block_reduce.h" +#include "kernels.h" +#ifndef COLOSSAL_HIP +#include + +namespace cg = cooperative_groups; +#endif + +const float LN_EPSILON = 1e-8f; +#define TILE_DIM 32 + +template __forceinline__ __device__ T add_eps(T x) { + return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); +} + +/** +@brief: ker_layer_norm +Standard layer normalization. +It will not only output the layer norm result, + but also outputs variance. + may also output means, depends on whether + the means argument is nullptr + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = hidden_size + +@param +ln_res: [batch_size* seq_len, hidden_size], ln result. +vars: [batch_size* seq_len], variance per token +means: [batch_size* seq_len], means per token, can be nullput +inp: [batch_size * seq_len, hidden_size], ln input. +scale: [hidden_size], ln scale +bias: [hidden_size], ln bias +*/ +template +__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp, + const T *scale, const T *bias, int hidden_size) { + // step 0. compute local sum + float l_sum = 0; + float l_square_sum = 0; + const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 val = inp_f4[idx]; + l_sum += val.x + val.y + val.z + val.w; + l_square_sum += + val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w; + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_size) * 4.f; + float reduce_val[2] = {l_sum, l_square_sum}; + blockReduce(reduce_val); + __shared__ float s_mean, s_var; + if (threadIdx.x == 0) { + s_mean = reduce_val[0] / mean_dim; + if (means != nullptr) { + means[blockIdx.x] = s_mean; + } + s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; + vars[blockIdx.x] = s_var; + s_var = rsqrtf(s_var); + } + __syncthreads(); + + // step 2. layer norm result + float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 vscale = __ldg((const float4 *)scale + idx); + float4 vbias = __ldg((const float4 *)bias + idx); + float4 val = inp_f4[idx]; + val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x; + val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y; + val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z; + val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w; + output_f4[idx] = val; + } +} + +template <> +__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, + __half *means, const __half *inp, + const __half *scale, const __half *bias, + int hidden_size) { + // step 0. compute local sum + float l_sum = 0; + float l_square_sum = 0; + const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 val_f4 = inp_f4[idx]; + __half2 *val_h2 = (__half2 *)(&val_f4); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 val_f2 = __half22float2(val_h2[i]); + l_sum += val_f2.x + val_f2.y; + l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; + } + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_size) * 8.f; + float reduce_val[2] = {l_sum, l_square_sum}; + blockReduce(reduce_val); + __shared__ float s_mean, s_var; + if (threadIdx.x == 0) { + s_mean = reduce_val[0] / mean_dim; + if (means != nullptr) { + means[blockIdx.x] = s_mean; + } + s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; + vars[blockIdx.x] = s_var; + s_var = rsqrtf(s_var); + } + __syncthreads(); + + // step 2. layer norm result + float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + // load scale, bias, input + float4 scale_f4 = __ldg((const float4 *)scale + idx); + __half2 *scale_h2 = (__half2 *)(&scale_f4); + float4 bias_f4 = __ldg((const float4 *)bias + idx); + __half2 *bias_h2 = (__half2 *)(&bias_f4); + float4 val_f4 = inp_f4[idx]; + __half2 *val_h2 = (__half2 *)(&val_f4); + +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 scale_f2 = __half22float2(scale_h2[i]); + float2 bias_f2 = __half22float2(bias_h2[i]); + float2 val_f2 = __half22float2(val_h2[i]); + val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; + val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; + val_h2[i] = __float22half2_rn(val_f2); + } + output_f4[idx] = val_f4; + } +} + +// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, +// __half *means, const __half *inp, +// const __half *scale, const __half +// *bias, int hidden_size) { +// // step 0. compute local sum +// float l_sum = 0; +// float l_square_sum = 0; +// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * +// 2) { +// float4 val_f4 = inp_f4[idx]; +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; +// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x +// * val_f2_1.x + val_f2_1.y * val_f2_1.y; +// } +// } + +// // step 1. compute reduce sum +// float mean_dim = float(hidden_size) * 8.f * 2; +// float reduce_val[2] = {l_sum, l_square_sum}; +// blockReduce(reduce_val); +// __shared__ float s_mean, s_var; +// if (threadIdx.x == 0) { +// s_mean = reduce_val[0] / mean_dim; +// if (means != nullptr) { +// means[blockIdx.x] = s_mean; +// } +// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; +// vars[blockIdx.x] = s_var; +// s_var = rsqrtf(s_var); +// } +// __syncthreads(); + +// // step 2. layer norm result +// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * +// 2) { +// // load scale, bias, input +// float4 scale_f4 = __ldg((const float4 *)scale + idx); +// __half2 *scale_h2 = (__half2 *)(&scale_f4); +// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); +// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); +// float4 bias_f4 = __ldg((const float4 *)bias + idx); +// __half2 *bias_h2 = (__half2 *)(&bias_f4); +// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); +// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); +// float4 val_f4 = inp_f4[idx]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 scale_f2 = __half22float2(scale_h2[i]); +// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); +// float2 bias_f2 = __half22float2(bias_h2[i]); +// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; +// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; +// val_h2[i] = __float22half2_rn(val_f2); +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + +// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y +// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1); +// } +// output_f4[idx] = val_f4; +// output_f4[idx+1] = val_f4_1; +// } +// } + +// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, +// __half *means, const __half *inp, +// const __half *scale, const __half +// *bias, int hidden_size) { +// // step 0. compute local sum +// float l_sum = 0; +// float l_square_sum = 0; +// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * +// 4) { +// float4 val_f4 = inp_f4[idx]; +// float4 val_f4_1 = inp_f4[idx+1]; +// float4 val_f4_2 = inp_f4[idx+2]; +// float4 val_f4_3 = inp_f4[idx+3]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); +// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// float2 val_f2_2 = __half22float2(val_h2_2[i]); +// float2 val_f2_3 = __half22float2(val_h2_3[i]); +// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + +// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x * +// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x +// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x + +// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x + +// val_f2_3.y * val_f2_3.y; +// } +// } + +// // step 1. compute reduce sum +// float mean_dim = float(hidden_size) * 8.f * 4; +// float reduce_val[2] = {l_sum, l_square_sum}; +// blockReduce(reduce_val); +// __shared__ float s_mean, s_var; +// if (threadIdx.x == 0) { +// s_mean = reduce_val[0] / mean_dim; +// if (means != nullptr) { +// means[blockIdx.x] = s_mean; +// } +// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; +// vars[blockIdx.x] = s_var; +// s_var = rsqrtf(s_var); +// } +// __syncthreads(); + +// // step 2. layer norm result +// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * +// 4) { +// // load scale, bias, input +// float4 scale_f4 = __ldg((const float4 *)scale + idx); +// __half2 *scale_h2 = (__half2 *)(&scale_f4); +// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); +// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); +// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2); +// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2); +// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3); +// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3); +// float4 bias_f4 = __ldg((const float4 *)bias + idx); +// __half2 *bias_h2 = (__half2 *)(&bias_f4); +// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); +// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); +// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2); +// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2); +// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3); +// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3); +// float4 val_f4 = inp_f4[idx]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// float4 val_f4_2 = inp_f4[idx+2]; +// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); +// float4 val_f4_3 = inp_f4[idx+3]; +// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 scale_f2 = __half22float2(scale_h2[i]); +// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); +// float2 scale_f2_2 = __half22float2(scale_h2_2[i]); +// float2 scale_f2_3 = __half22float2(scale_h2_3[i]); +// float2 bias_f2 = __half22float2(bias_h2[i]); +// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); +// float2 bias_f2_2 = __half22float2(bias_h2_2[i]); +// float2 bias_f2_3 = __half22float2(bias_h2_3[i]); +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// float2 val_f2_2 = __half22float2(val_h2_2[i]); +// float2 val_f2_3 = __half22float2(val_h2_3[i]); +// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; +// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + +// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y +// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var * +// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var +// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) * +// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean) +// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] = +// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1); +// val_h2_2[i] = __float22half2_rn(val_f2_2); +// val_h2_3[i] = __float22half2_rn(val_f2_3); +// } +// output_f4[idx] = val_f4; +// output_f4[idx+1] = val_f4_1; +// output_f4[idx+2] = val_f4_2; +// output_f4[idx+3] = val_f4_3; +// } +// } + +template <> +void launch_layer_norm(float *ln_res, float *vars, float *means, + const float *inp, const float *scale, + const float *bias, int batch_size, int hidden_dim, + cudaStream_t stream) { + if (hidden_dim % 4 != 0) { + throw std::runtime_error("violate hidden_dim % 4 = 0"); + } + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + dim3 grid_dim(batch_size); + dim3 block_dim(nthread); + + ker_layer_norm<<>>( + ln_res, vars, means, inp, scale, bias, hidden_dim); +} + +template <> +void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means, + const __half *inp, const __half *scale, + const __half *bias, int batch_size, + int hidden_dim, cudaStream_t stream) { + if (hidden_dim % 8 != 0) { + throw std::runtime_error("violate hidden_dim % 8 = 0"); + } + hidden_dim >>= 3; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + dim3 grid_dim(batch_size); + dim3 block_dim(nthread); + + ker_layer_norm<__half><<>>( + ln_res, vars, means, inp, scale, bias, hidden_dim); + // if (hidden_dim % 8 != 0) { + // throw std::runtime_error("violate hidden_dim % 8 = 0"); + // } + // hidden_dim >>= 3; + + // if (hidden_dim * 8 < 8192) { + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm<__half><<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) { + // hidden_dim >>= 1; + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm_x2<<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) { + // hidden_dim >>= 2; + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm_x4<<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else { + // throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); + // } +} + +/** +@brief: ker_ln_bw_dgamma_dbetta +Layer norm backword kernel, compute the gradient of gamma and betta. +dbetta = sum(dout, dim=0) +dgamma = sum(xhat * dout, dim=0) +xhat = (input - mean) * rsqrt(var) or + (output - betta) / gamma + + +@thread +gridDim.x = hidden_size / 32 +blockDim.x = 32 +blockDim.y = 32 + +@param +gamma_grad: [hidden_size], gradient of gamma +betta_grad: [hidden_size], gradient of betta +out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr + ln input if means is not nullptr +gamma: [hidden_size], gamma of ln, + used to compute xhat, maybe nullptr +betta: [hidden_size], betta of ln, + used to compute xhat, maybe nullptr +vars: [batch_size * seq_len], variance of ln forward, + used to compute xhat, maybe nullptr +means: [batch_size * seq_len], mean of ln forward, + used to compute xhat, maybe nullptr +(gamma && betta) ^ (vars && means) should be true +*/ +template +__global__ void +ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, const T *out_grad, + const T *inp_or_out, const T *gamma, const T *betta, + const T *vars, const T *means, int rows, int width) { + __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; + +#ifndef COLOSSAL_HIP + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); +#endif + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + // Loop across inp height + float dbetta = 0; + float dgamma = 0; + float dout, val; + if (idx < width) { + if (means == nullptr) { + float vbetta = (float)betta[idx]; + float vgamma = (float)gamma[idx]; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + dout = (float)out_grad[offset]; + // inp_or_out is output + val = (float)inp_or_out[offset]; + dbetta += dout; + dgamma += ((val - vbetta) / add_eps(vgamma) * dout); + offset += y_stride; + } + } else { + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + dout = (float)out_grad[offset]; + // inp_or_out is input + val = (float)inp_or_out[offset]; + dbetta += dout; + dgamma += ((val - (float)means[r]) * + rsqrtf((float)vars[r] + LN_EPSILON) * dout); + offset += y_stride; + } + } + } + + // Sum the shared buffer. + betta_buffer[threadIdx.x][threadIdx.y] = dbetta; + gamma_buffer[threadIdx.x][threadIdx.y] = dgamma; + __syncthreads(); + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + __syncthreads(); + + for (int i = 1; i < TILE_DIM; i <<= 1) { +#ifdef COLOSSAL_HIP + s1 += __shfl_down(s1, i); + s2 += __shfl_down(s2, i); +#else + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); +#endif + } + + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + if (threadIdx.x == 0 && idx < width) { + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +/** +@brief: ker_ln_bw_dinp +Layer norm backword kernel, compute the gradient of input. +dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) + * rsqrt(var) +xhat = (input - mean) * rsqrt(var) if mean is not nullptr + (output - betta) / gamma if mean is nullptr +dxhat = dout * gamma + + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = hidden_size + +@param +inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input, + usually appear in pre-layer-norm for transformer layer, maybe nullptr +inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr + ln input if means is not nullptr +gamma: [hidden_size], gamma of ln, + used to compute xhat and dxhat +betta: [hidden_size], betta of ln, + used to compute xhat, maybe nullptr +vars: [batch_size * seq_len], variance of ln forward, + used to compute xhat and dinp +means: [batch_size * seq_len], mean of ln forward, + used to compute xhat, maybe nullptr +*/ +template +__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, + const T *gamma, const T *betta, const T *vars, + const T *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim + threadIdx.x; + float4 dxhat, xhat; + float var_rsqrt; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + dxhat = ((const float4 *)out_grad)[offset]; + float4 vgamma = ((const float4 *)gamma)[threadIdx.x]; + dxhat.x *= vgamma.x; + dxhat.y *= vgamma.y; + dxhat.z *= vgamma.z; + dxhat.w *= vgamma.w; + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + xhat = ((const float4 *)inp_or_out)[offset]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[threadIdx.x]; + xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x); + xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y); + xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z); + xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w); + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; + xhat.x = (xhat.x - fmean) * var_rsqrt; + xhat.y = (xhat.y - fmean) * var_rsqrt; + xhat.z = (xhat.z - fmean) * var_rsqrt; + xhat.w = (xhat.w - fmean) * var_rsqrt; + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + float reduce_val[2] = {0.f, 0.f}; + if (threadIdx.x < hidden_dim) { + reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w; + reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + + dxhat.w * xhat.w; + } + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 4; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt; + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + dxhat.x += dresidual.x; + dxhat.y += dresidual.y; + dxhat.z += dresidual.z; + dxhat.w += dresidual.w; + } + ((float4 *)inp_grad)[offset] = dxhat; +} + +template <> +__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, + int hidden_dim) { + int offset = blockIdx.x * hidden_dim + threadIdx.x; + + float2 dxhat[4], xhat[4]; + float var_rsqrt; + float4 vtmp; + __half2 *tmp_h2; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[threadIdx.x]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vbetta = __half22float2(betta_h2[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); +#pragma unroll + for (int i = 0; i < 4; i++) { +#ifdef COLOSSAL_HIP + tmp_h2[i] = make_half2(__float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])), + __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1]))); +#else + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); +#endif + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { +#ifdef COLOSSAL_HIP + tmp_h2[i] = make_half2(__float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt), + __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt)); +#else + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); +#endif + } + } + ((float4 *)inp_grad)[offset] = vtmp; +} + +__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, const __half *gamma, + const __half *betta, const __half *vars, + const __half *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; + + float2 dxhat[4], xhat[4]; + float2 dxhat_1[4], xhat_1[4]; + float var_rsqrt; + float4 vtmp, vtmp_1; + __half2 *tmp_h2; + __half2 *tmp_h2_1; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + vtmp_1 = ((const float4 *)out_grad)[offset + 1]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2]; + float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); + __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vdout_1 = __half22float2(tmp_h2_1[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + dxhat_1[i].x = vdout_1.x * vgamma_1.x; + dxhat_1[i].y = vdout_1.y * vgamma_1.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x]; + float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); + __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vout_1 = __half22float2(tmp_h2_1[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vbetta = __half22float2(betta_h2[i]); + float2 vbetta_1 = __half22float2(betta_h2_1[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + float2 vinp_1 = __half22float2(tmp_h2_1[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8 * 2; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); + __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); +#pragma unroll + for (int i = 0; i < 4; i++) { +#ifdef COLOSSAL_HIP + tmp_h2[i] = make_half2(__float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])), + __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1]))); + tmp_h2_1[i] = make_half2(__float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])), + __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1]))); +#else + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); +#endif + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { +#ifdef COLOSSAL_HIP + tmp_h2[i] = make_half2(__float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt), + __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt)); + tmp_h2_1[i] = make_half2(__float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt), + __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt)); +#else + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt); +#endif + } + } + ((float4 *)inp_grad)[offset] = vtmp; + ((float4 *)inp_grad)[offset + 1] = vtmp_1; +} + +__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, const __half *gamma, + const __half *betta, const __half *vars, + const __half *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4; + + float2 dxhat[4], xhat[4]; + float2 dxhat_1[4], xhat_1[4]; + float2 dxhat_2[4], xhat_2[4]; + float2 dxhat_3[4], xhat_3[4]; + float var_rsqrt; + float4 vtmp, vtmp_1, vtmp_2, vtmp_3; + __half2 *tmp_h2; + __half2 *tmp_h2_1; + __half2 *tmp_h2_2; + __half2 *tmp_h2_3; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + vtmp_1 = ((const float4 *)out_grad)[offset + 1]; + vtmp_2 = ((const float4 *)out_grad)[offset + 2]; + vtmp_3 = ((const float4 *)out_grad)[offset + 3]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); + tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2); + tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4]; + float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1]; + float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2]; + float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); + __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); + __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2); + __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vdout_1 = __half22float2(tmp_h2_1[i]); + float2 vdout_2 = __half22float2(tmp_h2_2[i]); + float2 vdout_3 = __half22float2(tmp_h2_3[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vgamma_2 = __half22float2(gamma_h2_2[i]); + float2 vgamma_3 = __half22float2(gamma_h2_3[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + dxhat_1[i].x = vdout_1.x * vgamma_1.x; + dxhat_1[i].y = vdout_1.y * vgamma_1.y; + dxhat_2[i].x = vdout_2.x * vgamma_2.x; + dxhat_2[i].y = vdout_2.y * vgamma_2.y; + dxhat_3[i].x = vdout_3.x * vgamma_3.x; + dxhat_3[i].y = vdout_3.y * vgamma_3.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + + dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x + + dxhat_3[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; + vtmp_2 = ((const float4 *)inp_or_out)[offset + 2]; + vtmp_3 = ((const float4 *)inp_or_out)[offset + 3]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x]; + float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1]; + float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2]; + float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); + __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); + __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2); + __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vout_1 = __half22float2(tmp_h2_1[i]); + float2 vout_2 = __half22float2(tmp_h2_2[i]); + float2 vout_3 = __half22float2(tmp_h2_3[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vgamma_2 = __half22float2(gamma_h2_2[i]); + float2 vgamma_3 = __half22float2(gamma_h2_3[i]); + float2 vbetta = __half22float2(betta_h2[i]); + float2 vbetta_1 = __half22float2(betta_h2_1[i]); + float2 vbetta_2 = __half22float2(betta_h2_2[i]); + float2 vbetta_3 = __half22float2(betta_h2_3[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); + xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x); + xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); + xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); + xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += + xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + float2 vinp_1 = __half22float2(tmp_h2_1[i]); + float2 vinp_2 = __half22float2(tmp_h2_2[i]); + float2 vinp_3 = __half22float2(tmp_h2_3[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; + xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt; + xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; + xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; + xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += + xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8 * 4; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; + float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2]; + float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); + __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); + __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); + __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3); +#pragma unroll + for (int i = 0; i < 4; i++) { +#ifdef COLOSSAL_HIP + tmp_h2[i] = make_half2(__float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])), + __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1]))); + tmp_h2_1[i] = make_half2(__float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])), + __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1]))); + tmp_h2_2[i] = make_half2(__float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_2[2 * i])), + __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1]))); + tmp_h2_3[i] = make_half2(__float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_3[2 * i])), + __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1]))); +#else + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])); + tmp_h2_2[i].x = __float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_2[2 * i])); + tmp_h2_3[i].x = __float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_3[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + tmp_h2_2[i].y = __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + tmp_h2_3[i].y = __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); +#endif + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { +#ifdef COLOSSAL_HIP + tmp_h2[i] = make_half2(__float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt), + __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt)); + tmp_h2_1[i] = make_half2(__float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt), + __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt)); + tmp_h2_2[i] = make_half2(__float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt), + __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt)); + tmp_h2_3[i] = make_half2(__float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt), + __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt)); +#else + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_2[i].x = __float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_3[i].x = __float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_2[i].y = __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_3[i].y = __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt); +#endif + } + } + ((float4 *)inp_grad)[offset] = vtmp; + ((float4 *)inp_grad)[offset + 1] = vtmp_1; + ((float4 *)inp_grad)[offset + 2] = vtmp_2; + ((float4 *)inp_grad)[offset + 3] = vtmp_3; +} + +/** +Layer norm backword, + compute the gradient of gamma, betta and input. +dbetta = sum(dout, dim=0) +xhat = (input - mean) * rsqrt(var) if mean is not nullptr + (output - betta) / gamma if mean is nullptr +dgamma = sum(xhat * dout, dim=0) +dxhat = dout * gamma +dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim) + * rsqrt(var) + +residual_grad, means, betta can be nullptr. +residual_grad will be added to dinp if it is not nullptr + which is useful in transformer layer when pre-ln +means and betta are only used to compute xhat, + (means == nullptr) ^ (betta == nullptr) should be true +*/ +template <> +void launch_ln_bw(float *gamma_grad, float *betta_grad, float *inp_grad, + const float *out_grad, const float *residual_grad, + const float *inp_or_out, const float *gamma, + const float *betta, const float *vars, + const float *means, int batch, int hidden_dim, + cudaStream_t stream[2]) { + // compute grad of gamma and betta + dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + ker_ln_bw_dgamma_dbetta<<>>( + gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, + batch, hidden_dim); + + // compute grad of input + if (hidden_dim % 4 != 0 || hidden_dim > 4096) { + throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096"); + } + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, + hidden_dim); +} + +template <> +void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, + __half *inp_grad, const __half *out_grad, + const __half *residual_grad, const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, int batch, + int hidden_dim, cudaStream_t stream[2]) { + // compute grad of gamma and betta + dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + ker_ln_bw_dgamma_dbetta<__half><<>>( + gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, + batch, hidden_dim); + + // compute grad of input + if (hidden_dim % 8 != 0) { + throw std::runtime_error("hidden_dim % 8 != 0"); + } + hidden_dim >>= 3; + + if (hidden_dim * 8 <= 8192) { + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { + hidden_dim >>= 1; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp_x2<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp_x4<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else { + throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); + } +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..f1a6844c6baa8e70b20514e612421340b2b3d0d3 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu @@ -0,0 +1,392 @@ +#ifndef COLOSSAL_HIP +#include + +namespace cg = cooperative_groups; +#endif +#include + +#include +#include + +#include "block_reduce.h" +#include "kernels.h" + +const float EPSILON = 1e-8f; + +/** +@brief: softmax_kernel +Softmax forward kernel for + enc-self-attn, dec-self-attn, encdec-attn + +@thread +gridDim.x = dynamic +gridDim.y = batch_size +gridDim.z = nhead +blockDim.x = from_len + +@param +inp: [batch_size, nhead, from_len, to_len], softmax input. +attn_mask: [batch_size, to_len], padding tokens are -inf, + non padding tokens are 0. + attn_mask!=nullptr for enc-self-attn and enc-dec-attn + attn_mask=nullptr and mask_future=ture for dec-self-attn training + attn_mask=nullptr and mask_future=false for dec-self-attn infer +*/ +template +__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, + int to_len, bool mask_future) { + int batch_id = blockIdx.y; + int head_id = blockIdx.z; + const int nhead = gridDim.z; + const int token_per_reduce = 1; +#ifdef COLOSSAL_HIP + typedef hipcub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef hipcub::BlockStore + BlockStore; +#else + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; +#endif + __shared__ typename BlockStore::TempStorage ts_store; + + T mval[ele_per_thread]; + if (attn_mask) { + attn_mask += batch_id * to_len; + BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); + } + + inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); + for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; + token_id += gridDim.x * token_per_reduce) { + T inp_val[token_per_reduce][ele_per_thread]; + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, + REDUCE_FLOAT_INF_NEG); + } + + /* step 1. compute max */ + // thread local max + float val[token_per_reduce][ele_per_thread]; + float l_max[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_max[i] = REDUCE_FLOAT_INF_NEG; + for (int j = 0; j < ele_per_thread; j++) { + if (attn_mask) { + val[i][j] = (float)inp_val[i][j] + (float)mval[j]; + } else { + if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { + val[i][j] = REDUCE_FLOAT_INF_NEG; + } else { + val[i][j] = (float)inp_val[i][j]; + } + } + l_max[i] = fmaxf(l_max[i], val[i][j]); + } + } + // block reduce max + blockReduce(l_max); + // write shared + __shared__ float s_max[token_per_reduce]; + if (threadIdx.x == 0) { + for (int i = 0; i < token_per_reduce; i++) { + s_max[i] = l_max[i]; + } + } + __syncthreads(); + + /* step 2. compute sum */ + // thread local sum + float l_sum[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_sum[i] = 0.f; + for (int j = 0; j < ele_per_thread; j++) { + val[i][j] = __expf(val[i][j] - s_max[i]); + l_sum[i] += val[i][j]; + } + } + // block reduce sum + blockReduce(l_sum); + // write shared + __shared__ float s_sum[token_per_reduce]; + if (threadIdx.x == 0) { + for (int i = 0; i < token_per_reduce; i++) { + s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); + } + } + __syncthreads(); + + /* step 3. compute final result */ + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + for (int j = 0; j < ele_per_thread; j++) { + inp_val[i][j] = (T)(val[i][j] * s_sum[i]); + } + BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], + to_len); + } + } // blockIdx.x +} + +template +__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, + int to_len, bool mask_future) { + int batch_id = blockIdx.y; + int head_id = blockIdx.z; + const int nhead = gridDim.z; + const int token_per_reduce = 1; +#ifdef COLOSSAL_HIP + typedef hipcub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef hipcub::BlockStore + BlockStore; +#else + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; +#endif + __shared__ typename BlockStore::TempStorage ts_store; + + T mval[ele_per_thread]; + if (attn_mask) { + attn_mask += batch_id * to_len; + BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); + } + + inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); + for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; + token_id += gridDim.x * token_per_reduce) { + T inp_val[token_per_reduce][ele_per_thread]; + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, + REDUCE_FLOAT_INF_NEG); + } + + /* step 1. compute max */ + // thread local max + float val[token_per_reduce][ele_per_thread]; + float l_max[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_max[i] = REDUCE_FLOAT_INF_NEG; + for (int j = 0; j < ele_per_thread; j++) { + if (attn_mask) { + val[i][j] = (float)inp_val[i][j] + (float)mval[j]; + } else { + if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { + val[i][j] = REDUCE_FLOAT_INF_NEG; + } else { + val[i][j] = (float)inp_val[i][j]; + } + } + l_max[i] = fmaxf(l_max[i], val[i][j]); + } + } + // warp reduce max + warpReduce(l_max); + + /* step 2. compute sum */ + // thread local sum + float l_sum[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_sum[i] = 0.f; + for (int j = 0; j < ele_per_thread; j++) { + val[i][j] = __expf(val[i][j] - l_max[i]); + l_sum[i] += val[i][j]; + } + } + // warp reduce sum + warpReduce(l_sum); + + /* step 3. compute final result */ + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); + for (int j = 0; j < ele_per_thread; j++) { + inp_val[i][j] = (T)(val[i][j] * l_sum[i]); + } + BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], + to_len); + } + } // blockIdx.x +} + +/* + attn_mask!=nullptr for enc-self-attn and enc-dec-attn + attn_mask=nullptr and mask_future=ture for dec-self-attn training + attn_mask=nullptr and mask_future=false for dec-self-attn infer +*/ +template <> +void launch_attn_softmax(float *inp, const float *attn_mask, + int batch_size, int nhead, int from_len, + int to_len, bool mask_future, + cudaStream_t stream) { + dim3 grid_dim(1, batch_size, nhead); + if (to_len <= 32) { + ker_attn_softmax_lt32<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 64) { + ker_attn_softmax_lt32<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 128) { + grid_dim.x = 16; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 256) { + grid_dim.x = 32; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 512) { + grid_dim.x = 64; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else { + throw std::runtime_error( + "Sequence length greater than 512 is currently not supported"); + } +} + +template <> +void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask, + int batch_size, int nhead, int from_len, + int to_len, bool mask_future, + cudaStream_t stream) { + dim3 grid_dim(1, batch_size, nhead); + if (to_len <= 32) { + ker_attn_softmax_lt32<__half, 32, 1><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 64) { + ker_attn_softmax_lt32<__half, 32, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 128) { + grid_dim.x = 8; + ker_attn_softmax<__half, 64, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 256) { + grid_dim.x = 16; + ker_attn_softmax<__half, 128, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 512) { + grid_dim.x = 32; + ker_attn_softmax<__half, 256, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else { + throw std::runtime_error( + "Sequence length greater than 512 is currently not supported"); + } +} + +/** +@brief: ker_attn_softmax_bw +Softmax backward in self attention. + +@thread +gridDim.x = batch_size * nhead * seq_len / warps_per_block +blockDim.x = WARP_SIZE +blockDim.y = warps_per_block + +@param +grad: [batch_size, nhead, seq_len, seq_len], output grad. +output: [batch_size, nhead, seq_len, seq_len], output of softmax forward. +*/ +template +__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { + int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; + int offset = batch_idx * softmax_length + threadIdx.x; + + grad += offset; + inp += offset; + + T grad_reg[ITERATIONS]; + T inp_reg[ITERATIONS]; + float sum = 0.0; + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) { + grad_reg[i] = grad[i * WARP_SIZE]; + inp_reg[i] = inp[i * WARP_SIZE]; + sum += (float)grad_reg[i] * (float)inp_reg[i]; + } + } + +#ifdef COLOSSAL_HIP + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += __shfl_xor(sum, i); +#else + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); +#endif + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) + grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum)); + } +} + +template +void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, + int softmax_len, cudaStream_t stream) { + const int warps_per_block = 4; + // rows = batch_size * nhead * from_len + dim3 grid_dim(rows / warps_per_block); + dim3 block_dim(WARP_SIZE, warps_per_block); + + if (softmax_len <= 32) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 64) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 128) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 256) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 384) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 512) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 768) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 1024) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 2048) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else + throw std::runtime_error( + std::string( + "Special sequence length found in softmax backward, seq_len: ") + + std::to_string(softmax_len)); +} + +template void launch_attn_softmax_bw<__half>(__half *out_grad, + const __half *soft_inp, int rows, + int softmax_len, + cudaStream_t stream); +template void launch_attn_softmax_bw(float *out_grad, + const float *soft_inp, int rows, + int softmax_len, + cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..efb1f8a27de79ccf3eee50abb2c64f5039a22345 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu @@ -0,0 +1,323 @@ +#ifdef COLOSSAL_HIP +#include +//#include +//#include +//#include +#else +#include +#include +#include +#endif + +#include "kernels.h" + +#ifdef COLOSSAL_HIP +using namespace hipcub; +#else +using namespace cub; +#endif + +/** +@brief: transform_0213 +Split the attention heads and reshape input +during backward progress of encoder self-attention + +@thread +gridDim.x = batch_size +gridDim.y = seq_len +blockDim.x = min(hidden_dim, MAX_THREADS) + +@param +input: [batch_size, seq_len, hidden_dim] +output: [batch_size, nhead, seq_len, head_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +nhead: number of attention heads +*/ + +template +__global__ void transform_0213(T *output, const T *input, int hidden_dim, + int head_dim); + +template <> +__global__ void transform_0213(float *output, const float *input, + int hidden_dim, int head_dim) { + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + int seq_len = gridDim.y; + int nhead = hidden_dim / head_dim; + + // [b, s, h] + int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); + // [b, nh, s, ad] + int trg_offset = + flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + float4 vinput4; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinput4 = input4[src_offset + i]; + + int head_id = i / head_dim; + int dim_id = i % head_dim; + int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); + res4[trg_offset + cur_trg_offset] = vinput4; + } +} + +template <> +__global__ void transform_0213<__half>(__half *output, const __half *input, + int hidden_dim, int head_dim) { + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + int seq_len = gridDim.y; + int nhead = hidden_dim / head_dim; + + // [b, s, h] + int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); + // [b, nh, s, ad] + int trg_offset = + flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + float4 vinput4; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinput4 = input4[src_offset + i]; + + int head_id = i / head_dim; + int dim_id = i % head_dim; + int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); + res4[trg_offset + cur_trg_offset] = vinput4; + } +} + +// [b, s, h] -> [b, nh, s, ad] +template <> +void launch_transform_0213(float *output, const float *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, cudaStream_t stream) { + hidden_dim >>= 2; + int head_dim = hidden_dim / nhead; + + dim3 grid_dim(batch_size, seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + transform_0213 + <<>>(output, input, hidden_dim, head_dim); +} + +template <> +void launch_transform_0213<__half>(__half *output, const __half *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, cudaStream_t stream) { + hidden_dim >>= 3; + int head_dim = hidden_dim / nhead; + + dim3 grid_dim(batch_size, seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + transform_0213<__half> + <<>>(output, input, hidden_dim, head_dim); +} + +/** +@brief: bias_add_transform_20314 +Add bias to input, transform from +[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4] + +@thread +gridDim.x = dim_0 +gridDim.y = dim_1 +gridDim.z = dim_2 +blockDim.x = min(dim_3 * dim_4, MAX_THREADS) + +@param +input: [dim_0, dim_1, dim_2, dim_3, dim_4] +bias: [dim_2, dim_3, dim_4] +output: [dim_2, dim_0, dim_3, dim_1, dim_4] +*/ +template +__global__ void bias_add_transform_20314(T *output, const T *input, + const T *bias, int dim_3, int dim_4); + +template <> +__global__ void +bias_add_transform_20314(float *output, const float *input, + const float *bias, int dim_3, int dim_4) { + int id0 = blockIdx.x; + int id1 = blockIdx.y; + int id2 = blockIdx.z; + int dim_0 = gridDim.x; + int dim_1 = gridDim.y; + int dim_2 = gridDim.z; + int dim_34 = dim_3 * dim_4; + + int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); + int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); + int bias_offset = flat_2dim(id2, 0, dim_34); + + const float4 *qkv4 = reinterpret_cast(input); + const float4 *bias4 = reinterpret_cast(bias); + float4 *res4 = reinterpret_cast(output); + float4 vqkv4; + float4 vbias4; + float4 vres4; + + for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { + vqkv4 = qkv4[src_offset + i]; + vbias4 = bias4[bias_offset + i]; + vres4.x = vqkv4.x + vbias4.x; + vres4.y = vqkv4.y + vbias4.y; + vres4.z = vqkv4.z + vbias4.z; + vres4.w = vqkv4.w + vbias4.w; + + int id3 = i / dim_4; + int id4 = i % dim_4; + int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); + res4[trg_offset + cur_trg_offset] = vres4; + } +} + +template <> +__global__ void +bias_add_transform_20314<__half>(__half *output, const __half *input, + const __half *bias, int dim_3, int dim_4) { + int id0 = blockIdx.x; + int id1 = blockIdx.y; + int id2 = blockIdx.z; + int dim_0 = gridDim.x; + int dim_1 = gridDim.y; + int dim_2 = gridDim.z; + int dim_34 = dim_3 * dim_4; + + int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); + int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); + int bias_offset = flat_2dim(id2, 0, dim_34); + + const float4 *qkv4 = reinterpret_cast(input); + const float4 *bias4 = reinterpret_cast(bias); + float4 *res4 = reinterpret_cast(output); + float4 vqkv4; + float4 vbias4; + float4 vres4; + __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4); + __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4); + __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4); + + for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { + vqkv4 = qkv4[src_offset + i]; + vbias4 = bias4[bias_offset + i]; + h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]); + h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]); + h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]); + h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]); + + int id3 = i / dim_4; + int id4 = i % dim_4; + int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); + res4[trg_offset + cur_trg_offset] = vres4; + } +} + +// [b, s, 3, h] -> [3, b, nh, s, ad] +template <> +void launch_bias_add_transform_20314(float *output, const float *input, + const float *bias, int dim_0, + int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream) { + dim_4 >>= 2; + + dim3 grid_dim(dim_0, dim_1, dim_2); + dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); + + bias_add_transform_20314 + <<>>(output, input, bias, dim_3, dim_4); +} + +template <> +void launch_bias_add_transform_20314<__half>(__half *output, + const __half *input, + const __half *bias, int dim_0, + int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream) { + dim_4 >>= 3; + + dim3 grid_dim(dim_0, dim_1, dim_2); + dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); + + bias_add_transform_20314<__half> + <<>>(output, input, bias, dim_3, dim_4); +} + +/** +@brief: transform4d_0213 +Reshape the input matrix to merge the heads + +@thread +gridDim.x = (num_all + max_block_thread - 1) / max_block_thread +blockDim.x = max_block_thread + +@param +input: [trans_count, batch_size, nhead, seq_len, head_dim] +output: [batch_size, seq_len, trans_count, nhead, head_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +nhead: number of attention heads +trans_count: 1 or 3, the count of matrice need to be transformed +*/ +template +__global__ void transform4d_0213(T *output, const T *input, int batch_size, + int seq_len, int trans_count, int nhead, + int head_dim, int num_all) { + int offset = blockIdx.x * blockDim.x + threadIdx.x; + if (offset >= num_all) { + return; + } + int trans_id, batch_id, head_id, token_id, dim_id; + decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id, + &batch_id, &head_id, &token_id, &dim_id); + // [b, s, tc, nh, ad] + int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id, + seq_len, trans_count, nhead, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + res4[trg_offset] = input4[offset]; +} + +// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] +template <> +void launch_transform4d_0213(float *output, const float *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, int trans_count, + cudaStream_t stream) { + hidden_dim >>= 2; + int head_dim = hidden_dim / nhead; + int num_all = batch_size * seq_len * trans_count * hidden_dim; + int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; + + transform4d_0213<<>>( + output, input, batch_size, seq_len, trans_count, nhead, head_dim, + num_all); +} + +template <> +void launch_transform4d_0213<__half>(__half *output, const __half *input, + int batch_size, int seq_len, + int hidden_dim, int nhead, int trans_count, + cudaStream_t stream) { + hidden_dim >>= 3; + int head_dim = hidden_dim / nhead; + int num_all = batch_size * seq_len * trans_count * hidden_dim; + int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; + + transform4d_0213<__half><<>>( + output, input, batch_size, seq_len, trans_count, nhead, head_dim, + num_all); +} diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4690277e63db0a49c23c9274e6553da1e6b04103 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp @@ -0,0 +1,141 @@ +/*This code from NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include + +#include +#include + +#include "compat.h" + +namespace { + +void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1, + int &n2) { + int idiff = input.ndimension() - normalized_shape.size(); + n2 = 1; + for (int i = 0; i < (int)normalized_shape.size(); ++i) { + assert(input.sizes()[i + idiff] == normalized_shape[i]); + n2 *= normalized_shape[i]; + } + n1 = 1; + for (int i = 0; i < idiff; ++i) { + n1 *= input.sizes()[i]; + } +} + +void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma, + at::Tensor beta) { + TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); + TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); +} + +void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1, + int &n2) { + int64_t normalized_ndim = normalized_shape.size(); + + if (normalized_ndim < 1) { + std::stringstream ss; + ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " + << "containing at least one element, but got normalized_shape=" + << normalized_shape; + throw std::runtime_error(ss.str()); + } + + auto input_shape = input.sizes(); + auto input_ndim = input.dim(); + + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim) + .equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + throw std::runtime_error(ss.str()); + } + + compute_n1_n2(input, normalized_shape, n1, n2); +} + +void check_args(at::Tensor input, at::IntArrayRef normalized_shape, + at::Tensor gamma, at::Tensor beta, int &n1, int &n2) { + check_args(input, normalized_shape, n1, n2); + check_args(normalized_shape, gamma, beta); +} +} // namespace + +void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar, + at::Tensor *input, int n1, int n2, + at::IntArrayRef normalized_shape, at::Tensor *gamma, + at::Tensor *beta, double epsilon); + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +std::vector layer_norm_affine(at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, at::Tensor beta, + double epsilon) { + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor output = + at::empty_like(input, gamma.options().dtype(gamma.scalar_type())); + at::Tensor mean = + at::empty({n1}, input.options().dtype(at::ScalarType::Float)); + at::Tensor invvar = at::empty_like(mean); + + cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape, + &gamma, &beta, epsilon); + + return {output, mean, invvar}; +} + +void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean, + at::Tensor *invvar, at::Tensor *input, int n1, + int n2, at::IntArrayRef normalized_shape, + at::Tensor *gamma, at::Tensor *beta, + double epsilon, at::Tensor *grad_input, + at::Tensor *grad_gamma, at::Tensor *grad_beta); + +std::vector layer_norm_gradient_affine( + at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input, + at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta, + double epsilon) { + CHECK_INPUT(dout); + CHECK_INPUT(mean); + CHECK_INPUT(invvar); + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor grad_input = at::empty_like(input); + at::Tensor grad_gamma = at::empty_like(gamma); + at::Tensor grad_beta = at::empty_like(beta); + + cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, + normalized_shape, &gamma, &beta, epsilon, + &grad_input, &grad_gamma, &grad_beta); + + return {grad_input, grad_gamma, grad_beta}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); + m.def("backward_affine", &layer_norm_gradient_affine, + "LayerNorm backward (CUDA)"); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..81014daa2a794dfa20e22d75669046ad99b1e096 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu @@ -0,0 +1,706 @@ +/*This code from NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include +#include + +#include "ATen/ATen.h" +#include "ATen/AccumulateType.h" +#include "ATen/cuda/CUDAContext.h" +#include "ATen/cuda/DeviceUtils.cuh" +#include "type_shim.h" + +template +__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) { + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template +__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, + U& mu, U& sigma2, U& count) { + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA * mu + nB * muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +template +__device__ void cuWelfordMuSigma2(const T* __restrict__ vals, const int n1, + const int n2, const int i1, U& mu, U& sigma2, + U* buf) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + U count = U(0); + mu = U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1 * n2; + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l + k]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + U muB = WARP_SHFL(mu, srcLaneB); + U countB = WARP_SHFL(count, srcLaneB); + U sigma2B = WARP_SHFL(sigma2, srcLaneB); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U* ubuf = (U*)buf; + U* ibuf = (U*)(ubuf + blockDim.y); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2 * wrt_y] = mu; + ubuf[2 * wrt_y + 1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U muB = ubuf[2 * threadIdx.y]; + U sigma2B = ubuf[2 * threadIdx.y + 1]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1] / U(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2 / U(n2), 0); + } + } +} + +template <> +__device__ void cuWelfordMuSigma2(const at::Half* __restrict__ vals, + const int n1, const int n2, const int i1, + float& mu, float& sigma2, float* buf) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + float count = 0.0f; + mu = float(0); + sigma2 = float(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const at::Half* lvals = vals + i1 * n2; + int l = 8 * thrx; + if ((((size_t)lvals) & 3) != 0) { + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. + for (; l + 7 < n2; l += 8 * numx) { + for (int k = 0; k < 8; k += 2) { + float2 curr = __half22float2(*((__half2*)(lvals + l + k))); + cuWelfordOnlineSum(curr.x, mu, sigma2, count); + cuWelfordOnlineSum(curr.y, mu, sigma2, count); + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + float muB = WARP_SHFL(mu, srcLaneB); + float countB = WARP_SHFL(count, srcLaneB); + float sigma2B = WARP_SHFL(sigma2, srcLaneB); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + float* ubuf = (float*)buf; + float* ibuf = (float*)(ubuf + blockDim.y); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2 * wrt_y] = mu; + ubuf[2 * wrt_y + 1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float muB = ubuf[2 * threadIdx.y]; + float sigma2B = ubuf[2 * threadIdx.y + 1]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1] / float(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2 / float(n2), 0); + } + } +} + +#ifdef COLOSSAL_HIP + template __device__ + U rsqrt(U v) { + return U(1) / sqrt(v); + } + template<> __device__ + float rsqrt(float v) { + return rsqrtf(v); + } + template<> __device__ + double rsqrt(double v) { + return rsqrt(v); + } +#else + template + U rsqrt(U v) { + return U(1) / sqrt(v); + } + template<> + float rsqrt(float v) { + return rsqrtf(v); + } + template<> + double rsqrt(double v) { + return rsqrt(v); + } +#endif + +namespace { +// This is the un-specialized struct. Note that we prevent instantiation of +// this struct by putting an undefined symbol in the function body so it won't +// compile. +// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template +struct SharedMemory; + +template <> +struct SharedMemory { + __device__ float* getPointer() { + extern __shared__ float s_float[]; + return s_float; + } +}; + +} // namespace + +template +__global__ void cuApplyLayerNorm(V* __restrict__ output_vals, + U* __restrict__ mean, U* __restrict__ invvar, + const T* __restrict__ vals, const int n1, + const int n2, const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensors are contiguous + // +#ifdef COLOSSAL_HIP + for (size_t i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { +#else + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { +#endif + SharedMemory shared; + U* buf = shared.getPointer(); + U mu, sigma2; + cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf); + const T* lvals = vals + i1 * n2; + V* ovals = output_vals + i1 * n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && beta != NULL) { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; + } + } else { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + ovals[i] = static_cast(c_invvar * (curr - mu)); + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + mean[i1] = mu; + invvar[i1] = c_invvar; + } + } +} + +template +__device__ void cuLoadWriteStridedInputs( + const int i1_block, const int thr_load_row_off, const int thr_load_col_off, + const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2, + const T* input, const V* dout, const int i1_end, const int n2, + const U* __restrict__ mean, const U* __restrict__ invvar) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = + curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } +} + +template +__device__ void cuLoadAddStridedInputs( + const int i1_block, const int thr_load_row_off, const int thr_load_col_off, + const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2, + const T* input, const V* dout, const int i1_end, const int n2, + const U* __restrict__ mean, const U* __restrict__ invvar) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += + curr_dout * (curr_input - curr_mean) * curr_invvar; + } + } + } +} + +template +__global__ void cuComputePartGradGammaBeta( + const V* __restrict__ dout, const T* __restrict__ input, const int n1, + const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, + U epsilon, U* part_grad_gamma, U* part_grad_beta) { + const int numsegs_n1 = + (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); + const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y; + const int i1_beg_plus_one = + (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y; + const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; + const int row_stride = blockDim.x + 1; + const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1); + const int thr_load_row_off = + (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + SharedMemory shared; + U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * + // blockDim.y + (blockDim.y - + // 1)*(blockDim.x/blockDim.y) elements + U* warp_buf1 = (U*)buf; + U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, + row_stride, warp_buf1, warp_buf2, input, dout, + i1_end, n2, mean, invvar); + for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; + i1_block += blockDim.y * blockDim.y) { + cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, + row_stride, warp_buf1, warp_buf2, input, dout, + i1_end, n2, mean, invvar); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k * blockDim.y; + int idx1 = row1 * row_stride + threadIdx.x; + acc1 += warp_buf1[idx1]; + acc2 += warp_buf2[idx1]; + } + warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1; + warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y / 2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + warp_buf1[idx1] += warp_buf1[idx2]; + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2]; + part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template +__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, + const U* part_grad_beta, + const int part_size, const int n1, + const int n2, V* grad_gamma, + V* grad_beta) { + // sum partial gradients for gamma and beta + SharedMemory shared; + U* buf = shared.getPointer(); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + U sum_gamma = U(0); + U sum_beta = U(0); + const U* part_grad_gamma_ptr = + part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U* part_grad_beta_ptr = + part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; + ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; + sum_beta += part_grad_beta_ptr[warp_offset * n2]; + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + buf[write_idx + nbsize3] = sum_beta; + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + sum_beta += buf[read_idx + nbsize3]; + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + grad_beta[i2] = sum_beta; + } + } +} + +template +__global__ void cuComputeGradInput(const V* __restrict__ dout, + const T* __restrict__ input, const int n1, + const int n2, const U* __restrict__ mean, + const U* __restrict__ invvar, U epsilon, + const V* gamma, T* grad_input) { +#ifdef COLOSSAL_HIP + for (size_t i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { +#else + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { +#endif + U sum_loss1 = U(0); + U sum_loss2 = U(0); + const U c_mean = mean[i1]; + const U c_invvar = invvar[i1]; + const T* k_input = input + i1 * n2; + const V* k_dout = dout + i1 * n2; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + sum_loss1 += c_loss * gamma[l + k]; + sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss * gamma[l]; + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } + } else { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + // intra-warp reductions + for (int mask = blockDim.x / 2; mask > 0; mask /= 2) { + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (blockDim.y > 1) { + SharedMemory shared; + U* buf = shared.getPointer(); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[2 * wrt_i] = sum_loss1; + buf[2 * wrt_i + 1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + sum_loss1 += buf[2 * read_i]; + sum_loss2 += buf[2 * read_i + 1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + buf[2 * threadIdx.x] = sum_loss1; + buf[2 * threadIdx.x + 1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y != 0) { + sum_loss1 = buf[2 * threadIdx.x]; + sum_loss2 = buf[2 * threadIdx.x + 1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + i1 * n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * gamma[l]; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + } +} + +template +void HostApplyLayerNorm(V* output, U* mean, U* invvar, const T* input, int n1, + int n2, double epsilon, const V* gamma, const V* beta) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const dim3 threads(32, 4, 1); + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; + cuApplyLayerNorm<<>>( + output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); +} + +void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, + at::Tensor* input, int n1, int n2, +#ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, +#else + at::IntList normalized_shape, +#endif + at::Tensor* gamma, at::Tensor* beta, double epsilon) { + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel", + HostApplyLayerNorm(output->DATA_PTR(), + mean->DATA_PTR(), invvar->DATA_PTR(), + input->DATA_PTR(), n1, n2, epsilon, + gamma != NULL ? gamma->DATA_PTR() : NULL, + beta != NULL ? beta->DATA_PTR() : NULL);) +} + +template +void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, + at::Tensor* input, int n1, int n2, const V* gamma, + const V* beta, double epsilon, T* grad_input, + V* grad_gamma, V* grad_beta) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + if (gamma != NULL && beta != NULL) { + // compute grad_gamma(j) and grad_beta(j) + const int part_size = 16; + const dim3 threads2(32, 4, 1); + const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); + const int nshared2_a = + 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; + at::Tensor part_grad_gamma = at::empty( + {part_size, n2}, input->options().dtype(at::ScalarType::Float)); + at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); + cuComputePartGradGammaBeta<<>>( + dout, input->DATA_PTR(), n1, n2, mean, invvar, U(epsilon), + part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR()); + + const dim3 threads3(32, 8, 1); + const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR(), part_size, + n1, n2, grad_gamma, grad_beta); + } + + // compute grad_input + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); + const dim3 threads1(32, 4, 1); + int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; + cuComputeGradInput<<>>( + dout, input->DATA_PTR(), n1, n2, mean, invvar, U(epsilon), gamma, + grad_input); +} + +void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, + at::Tensor* invvar, at::Tensor* input, int n1, + int n2, +#ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, +#else + at::IntList normalized_shape, +#endif + at::Tensor* gamma, at::Tensor* beta, + double epsilon, at::Tensor* grad_input, + at::Tensor* grad_gamma, at::Tensor* grad_beta) { + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), gamma->scalar_type(), + "cuda_layer_norm_gradient_kernel", + HostLayerNormGradient( + dout->DATA_PTR(), mean->DATA_PTR(), + invvar->DATA_PTR(), input, n1, n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + gamma != NULL ? gamma->DATA_PTR() : NULL, + gamma != NULL ? beta->DATA_PTR() : NULL, epsilon, + grad_input->DATA_PTR(), + gamma != NULL ? grad_gamma->DATA_PTR() : NULL, + gamma != NULL ? grad_beta->DATA_PTR() : NULL);) +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8c0b89eb06d16d5a35c273acb755642399398750 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp @@ -0,0 +1,97 @@ +#include + +torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx); + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +torch::Tensor moe_dispatch_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, torch::Tensor dest_idx) { + CHECK_INPUT(batch_tokens); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx); +} + +torch::Tensor moe_dispatch_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(expert_grad); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx); +} + +torch::Tensor moe_combine_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(expert_tokens); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask, + dest_idx); +} + +std::vector moe_combine_backward(int s, int e, int c, int h, + torch::Tensor tokens_grad, + torch::Tensor expert_tokens, + torch::Tensor logits, + torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(tokens_grad); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens, + logits, mask, dest_idx); +} + +torch::Tensor moe_cumsum(torch::Tensor mask) { + CHECK_INPUT(mask); + return cumsum_sub_one_in_dim0(mask); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0"); + m.def("dispatch_forward", &moe_dispatch_forward, + "Forward operation in MoE dispatch function"); + m.def("dispatch_backward", &moe_dispatch_backward, + "Backward operation in MoE dispatch function"); + m.def("combine_forward", &moe_combine_forward, + "Combine operation in MoE combine function"); + m.def("combine_backward", &moe_combine_backward, + "Combine operation in MoE combine function"); +} diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..eec74034e7e41823b26f1e5f61dcb3bfabec0cc9 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu @@ -0,0 +1,714 @@ +#include +#include +#include + +#include + +#include "block_reduce.h" + +template +__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + +#ifdef COLOSSAL_HIP + typedef hipcub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef hipcub::BlockStore + BlockStore; +#else + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; +#endif + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + +#ifdef COLOSSAL_HIP + typedef hipcub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef hipcub::BlockStore + BlockStore; +#else + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; +#endif + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, pack); + BlockStore(ts_store).Store(src_row + idx, pack); + } +} + +template +__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; +#ifdef COLOSSAL_HIP + typedef hipcub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef hipcub::BlockStore + BlockStore; +#else + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; +#endif + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row1 + idx, pack); + BlockStore(ts_store).Store(dst_row2 + idx, pack); + } +} + +template +__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + +#ifdef COLOSSAL_HIP + typedef hipcub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef hipcub::BlockStore + BlockStore; +#else + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; +#endif + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row1 + idx, pack1); + BlockLoad(ts_load).Load(dst_row2 + idx, pack2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] += pack2[i]; + } + + BlockStore(ts_store).Store(src_row + idx, pack1); + } +} + +template +__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + +#ifdef COLOSSAL_HIP + typedef hipcub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef hipcub::BlockStore + BlockStore; +#else + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; +#endif + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack[i] *= weight; + } + + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, + T *weight_grad, const T weight, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + +#ifdef COLOSSAL_HIP + typedef hipcub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef hipcub::BlockStore + BlockStore; +#else + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; +#endif + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens[pack_size]; + float thread_sum = 0; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row + idx, tokens); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum += grad[i] * tokens[i]; + grad[i] *= weight; + } + + BlockStore(ts_store).Store(src_row + idx, grad); + } + + blockReduce(&thread_sum); + + if (threadIdx.x == 0) *weight_grad = static_cast(thread_sum); +} + +template +__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row, + const T weight1, const T weight2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + +#ifdef COLOSSAL_HIP + typedef hipcub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef hipcub::BlockStore + BlockStore; +#else + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; +#endif + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row1 + idx, pack1); + BlockLoad(ts_load).Load(src_row2 + idx, pack2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] = pack1[i] * weight1 + pack2[i] * weight2; + } + + BlockStore(ts_store).Store(dst_row + idx, pack1); + } +} + +template +__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row, + T *tks_row1, T *tks_row2, T *weight_grad1, + T *weight_grad2, const T weight1, + const T weight2, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + +#ifdef COLOSSAL_HIP + typedef hipcub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef hipcub::BlockStore + BlockStore; +#else + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; +#endif + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size], + sgrad2[pack_size]; + float thread_sum[2] = {0, 0}; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row1 + idx, tokens1); + BlockLoad(ts_load).Load(tks_row2 + idx, tokens2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum[0] += grad[i] * tokens1[i]; + thread_sum[1] += grad[i] * tokens2[i]; + sgrad1[i] = weight1 * grad[i]; + sgrad2[i] = weight2 * grad[i]; + } + + BlockStore(ts_store).Store(src_row1 + idx, sgrad1); + BlockStore(ts_store).Store(src_row2 + idx, sgrad2); + } + + blockReduce(thread_sum); + + if (threadIdx.x == 0) + *weight_grad1 = static_cast(thread_sum[0]); + else if (threadIdx.x == 1) + *weight_grad2 = static_cast(thread_sum[1]); +} + +// DISPATCH KERNELS -------------------------------- + +template +__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2, + const int cols, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_fwd(src_row, dst_row1, dst_row2, + cols); + else if (indicator1 != 0) + moe_dpch_one_fwd(src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_fwd(src_row, dst_row2, cols); + else + return; +} + +template +__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2, + const int cols, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_bwd(src_row, dst_row1, dst_row2, + cols); + else if (indicator1 != 0) + moe_dpch_one_bwd(src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_bwd(src_row, dst_row2, cols); + else + return; +} + +template +__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input, + int *mask1, int *mask2, int *dest1, + int *dest2, const int h) { + int row = blockIdx.x; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + moe_dpch_fwd_selector( + batch_tokens + (row * h), expert_input + (dest1[row] * h), + expert_input + (dest2[row] * h), h, mask1[row], indicator2); +} + +template +__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1, + int *mask2, int *dest1, int *dest2, + const int h) { + int row = blockIdx.x; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + moe_dpch_bwd_selector( + tokens_grad + (row * h), expert_grad + (dest1[row] * h), + expert_grad + (dest2[row] * h), h, mask1[row], indicator2); +} + +// COMBINE KERNELS -------------------------------- + +template +__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row, + const int cols, const T weight1, + const T weight2, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_fwd(src_row1, src_row2, dst_row, + weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_fwd(src_row1, dst_row, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_fwd(src_row2, dst_row, weight2, cols); + else + return; +} + +template +__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row, + const int cols, T *tks_row1, T *tks_row2, + T *wt_grad1, T *wt_grad2, const T weight1, + const T weight2, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_bwd(src_row1, src_row2, dst_row, + tks_row1, tks_row2, wt_grad1, + wt_grad2, weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_bwd(src_row1, dst_row, tks_row1, + wt_grad1, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_bwd(src_row2, dst_row, tks_row2, + wt_grad2, weight2, cols); + else + return; +} + +template +__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens, + T *logits, int *mask1, int *mask2, int *dest1, + int *dest2, const int e, const int c, + const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + T *row_log = logits + (row * e); + moe_cb_fwd_selector( + expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h), + combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row], + indicator2); +} + +template +__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks, + T *logits, T *logits_grad, int *mask1, + int *mask2, int *dest1, int *dest2, + const int e, const int c, const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e); + moe_cb_bwd_selector( + expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h), + tokens_grad + (row * h), h, tks + (dest1[row] * h), + tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1], + row_log[eid2], mask1[row], indicator2); +} + +// CUMSUM KERNEL -------------------------------- + +template +__global__ void cumsum_kernel(int *inputs, int *outputs, const int s, + const int e) { + assert(s % pack_size == 0); + constexpr int bpack_size = block_size * pack_size; + int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1; + __shared__ int temp[block_size + 1]; + int pack[pack_size]; + + for (int idx = 0; idx < s; idx += bpack_size) { + int offset = 1; + + if (idx + tps < s) { + temp[tid] = inputs[tps * e + bid]; +#pragma unroll + for (int i = 1; i < pack_size; ++i) { + pack[i] = inputs[(tps + i) * e + bid]; + } +#pragma unroll + for (int i = 1; i < pack_size; ++i) { + temp[tid] += pack[i]; + } + } + + for (int i = block_size >> 1; i > 0; i >>= 1) { + __syncthreads(); + if (tid < i) { + int j = offset * (2 * tid + 1) - 1; + temp[j + offset] += temp[j]; + } + offset <<= 1; + } + + if (tid == 0) { + temp[block_size] = temp[block_size - 1]; + temp[block_size - 1] = 0; + } + + for (int i = 1; i < block_size; i <<= 1) { + offset >>= 1; + __syncthreads(); + if (tid < i) { + int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j]; + temp[j] = temp[k]; + temp[k] += ts; + } + } + __syncthreads(); + + if (tid == 0) temp[0] = temp[block_size]; + __syncthreads(); + + if (idx + tps < s) { + temp[tid + 1] += last_sum; +#pragma unroll + for (int i = pack_size - 1; i > 0; --i) { + outputs[(tps + i) * e + bid] = temp[tid + 1]; + temp[tid + 1] -= pack[i]; + } + outputs[tps * e + bid] = temp[tid + 1]; + } + __syncthreads(); + + last_sum += temp[0]; + inputs += bpack_size * e; + outputs += bpack_size * e; + } +} + +// LAUNCH FUNCTIONS -------------------------------- + +template +void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1, + int *mask2, int *dest1, int *dest2, const int s, + const int h) { + if (h < 256) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); +} + +template +void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2, + int *dest1, int *dest2, const int s, const int h) { + if (h < 256) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); +} + +template +void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits, + int *mask1, int *mask2, int *dest1, int *dest2, + const int s, const int e, const int c, const int h) { + if (h < 256) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 512) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 1024) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 2048) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, + dest2, e, c, h); +} + +template +void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits, + T *logits_grad, int *mask1, int *mask2, int *dest1, + int *dest2, const int s, const int e, const int c, + const int h) { + if (h < 256) + moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, + logits, logits_grad, mask1, mask2, + dest1, dest2, e, c, h); + else // if (h < 512) + moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, + logits, logits_grad, mask1, mask2, + dest1, dest2, e, c, h); + // else if (h < 1024) + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, + // dest1, dest2, e, c, h); + // else + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, + // dest1, dest2, e, c, h); +} + +void cumsum_launch(int *inputs, int *outputs, const int s, const int e) { + if (s <= 256) + cumsum_kernel<256, 1><<>>(inputs, outputs, s, e); + else if (s <= 512) + cumsum_kernel<512, 1><<>>(inputs, outputs, s, e); + else if (s <= 1024) + cumsum_kernel<1024, 1><<>>(inputs, outputs, s, e); + else if (s <= 2048) + cumsum_kernel<1024, 2><<>>(inputs, outputs, s, e); + else + cumsum_kernel<1024, 4><<>>(inputs, outputs, s, e); +} + +// API FUNCTIONS -------------------------------- + +#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented yet for specific data type."); \ + } + +torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + auto res = torch::zeros( + {ec, h}, + torch::dtype(batch_tokens.dtype()).device(batch_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + batch_tokens.scalar_type(), "moe dispatch forward", + moe_dpch_fwd_launch( + batch_tokens.data(), res.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); + + return res; +} + +torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + auto res = torch::zeros( + {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_grad.scalar_type(), "moe dispatch backward", + moe_dpch_bwd_launch( + res.data(), expert_grad.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); + + return res; +} + +torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + assert(expert_tokens.dtype() == logits.dtype()); + + auto res = torch::zeros( + {s, h}, + torch::dtype(expert_tokens.dtype()).device(expert_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_tokens.scalar_type(), "moe combine forward", + moe_cb_fwd_launch( + expert_tokens.data(), res.data(), + logits.data(), mask[0].data(), + k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, + h)); + + return res; +} + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + assert(tokens_grad.dtype() == expert_tokens.dtype()); + assert(expert_tokens.dtype() == logits.dtype()); + + auto egrad = torch::zeros( + {e * c, h}, + torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())), + wgrad = torch::zeros( + {s, e}, torch::dtype(logits.dtype()).device(logits.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + tokens_grad.scalar_type(), "moe combine backward", + moe_cb_bwd_launch( + tokens_grad.data(), egrad.data(), + expert_tokens.data(), logits.data(), + wgrad.data(), mask[0].data(), + k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, + h)); + + return {egrad, wgrad}; +} + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { + assert(mask.dim() == 2); + assert(mask.dtype() == torch::kInt32); + + const int s = mask.size(0), e = mask.size(1); + auto res = + torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device())); + cumsum_launch(mask.data(), res.data(), s, e); + + return res; +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu new file mode 100644 index 0000000000000000000000000000000000000000..afd34bb96352873f12824a07cbf0062f5983a18a --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu @@ -0,0 +1,141 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include "multi_tensor_apply.cuh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +typedef enum { + ADAM_MODE_0 = 0, // L2 regularization mode + ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW) +} adamMode_t; + +using MATH_T = float; + +template +struct AdamFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, + const float beta1, const float beta2, const float beta1_correction, + const float beta2_correction, const float epsilon, const float lr, + adamMode_t mode, const float decay, const float div_scale) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + + // potentially use to pass in list of scalar + // int tensor_num = tl.start_tensor_this_launch + tensor_loc; + + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T_g *g = (T_g *)tl.addresses[0][tensor_loc]; + g += chunk_idx * chunk_size; + + T_p *p = (T_p *)tl.addresses[1][tensor_loc]; + p += chunk_idx * chunk_size; + + T_p *m = (T_p *)tl.addresses[2][tensor_loc]; + m += chunk_idx * chunk_size; + + T_p *v = (T_p *)tl.addresses[3][tensor_loc]; + v += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // see note in multi_tensor_scale_kernel.cu + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = g[i]; + r_p[ii] = p[i]; + r_m[ii] = m[i]; + r_v[ii] = v[i]; + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (div_scale > 0) r_g[ii] /= div_scale; + + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (lr * update); + } else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (lr * update); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p[i] = r_p[ii]; + m[i] = r_m[ii]; + v[i] = r_v[ii]; + } + } + } + } +}; + +void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, + const float beta2, const float epsilon, + const int step, const int mode, + const int bias_correction, const float weight_decay, + const float div_scale) { + using namespace at; + + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + + DISPATCH_FLOAT_AND_HALF_FOR_G_P( + tensor_lists[0][0].scalar_type(), tensor_lists[1][0].scalar_type(), 0, + "adam", + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctor(), beta1, + beta2, bias_correction1, bias_correction2, epsilon, + lr, (adamMode_t)mode, weight_decay, div_scale);) + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh b/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9ce41191133eeac43534d5d4d33dc0071c4ccb27 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh @@ -0,0 +1,133 @@ +// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh +#include +#include +#include +#include +#include +#include "compat.h" + +#include + +// #include + +// This header is the one-stop shop for all your multi-tensor apply needs. + +// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) +constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; +constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; + +template +struct TensorListMetadata +{ + void *addresses[n][depth_to_max_tensors[n - 1]]; + int sizes[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int. + int start_tensor_this_launch; +}; + +template +__global__ void multi_tensor_apply_kernel( + int chunk_size, + volatile int *noop_flag, + T tl, + U callable, + ArgTypes... args) +{ + // Hand the chunk information to the user-supplied functor to process however it likes. + callable(chunk_size, noop_flag, tl, args...); +} + +template +void multi_tensor_apply( + int block_size, + int chunk_size, + const at::Tensor &noop_flag, + const std::vector> &tensor_lists, + T callable, + ArgTypes... args) +{ + TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + int len0 = tensor_lists[0].size(); + TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); + auto ref_device = tensor_lists[0][0].device(); + TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); + for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices + { + TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); + for (int t = 0; t < tensor_lists[l].size(); t++) + { + // TODO: Print which tensor fails. + bool contiguous_memory = tensor_lists[l][t].is_contiguous(); +#ifdef VERSION_GE_1_5 + contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast)); +#endif + TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); + TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor"); + TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch"); + } + } + + int ntensors = tensor_lists[0].size(); + + TensorListMetadata tl; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); + auto stream = at::cuda::getCurrentCUDAStream(); + + tl.start_tensor_this_launch = 0; + int loc_block_info = 0; + int loc_tensor_info = 0; + for (int t = 0; t < ntensors; t++) + { + tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) + tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); + loc_tensor_info++; + + int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + + for (int chunk = 0; chunk < chunks_this_tensor; chunk++) + { + // std::cout << chunks_this_tensor << std::endl; + tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tl.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks_this_tensor - 1); + bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); + bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); + if (tensors_full || blocks_full || last_chunk) + { + // using accscalar_t = acc_type; + multi_tensor_apply_kernel<<>>( + chunk_size, + noop_flag.DATA_PTR(), + tl, + callable, + args...); + + AT_CUDA_CHECK(cudaGetLastError()); + + // Reset. The control flow possibilities here make my brain hurt. + loc_block_info = 0; + if (chunk == chunks_this_tensor - 1) + { + // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl; + loc_tensor_info = 0; + tl.start_tensor_this_launch = t + 1; + } + else + { + // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl; + tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + loc_tensor_info = 1; + tl.start_tensor_this_launch = t; + } + } + } + } +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..49ab83e8fc81df4c9887d55ddc5503f20498bb7d --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu @@ -0,0 +1,382 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_l2norm_kernel.cu +#include +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include "multi_tensor_apply.cuh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +template +__device__ __forceinline__ bool is_aligned(T *p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, + int src_offset) { + typedef + typename std::aligned_storage::type LT; + ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; +} + +template +struct L2NormFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl, + float *output, float *output_per_tensor, bool per_tensor, + int max_chunks_per_tensor) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + x_t *x = (x_t *)tl.addresses[0][tensor_loc]; + x += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + __shared__ float s_vals[512]; + + float vals[ILP]; // = {0}; // this probably works too but I want to be + // sure... + x_t r_x[ILP]; + for (int i = 0; i < ILP; i++) { + vals[i] = 0.f; + r_x[i] = 0; + } + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) { + for (int i_start = threadIdx.x; + i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_x, x, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float next = static_cast(r_x[ii]); + vals[ii] += next * next; + } + } + } else { + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + float next = static_cast(x[i]); + vals[ii] += next * next; + } + } + } + } + + float val = 0.f; + for (int i = 0; i < ILP; i++) val += vals[i]; + + float final = reduce_block_into_lanes(s_vals, val); + + if (threadIdx.x == 0) { + if (!isfinite(final)) + *noop_gmem = + 1; // Blindly fire off a write. These will race but that's ok. + output[blockIdx.x] += final; + if (per_tensor) + output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * + max_chunks_per_tensor + + chunk_idx] = final; + } + } +}; + +// Probably better to template, but since we are not likely to support other +// norm +template +struct MaxNormFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl, + float *output, float *output_per_tensor, bool per_tensor, + int max_chunks_per_tensor) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + x_t *x = (x_t *)tl.addresses[0][tensor_loc]; + x += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + __shared__ float s_vals[512]; + + float vals[ILP]; // = {0}; // this probably works too but I want to be + // sure... + x_t r_x[ILP]; + for (int i = 0; i < ILP; i++) { + vals[i] = 0.f; + r_x[i] = 0; + } + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) { + for (int i_start = threadIdx.x; + i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_x, x, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float next = static_cast(r_x[ii]); + vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next)); + } + } + } else { + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + float next = static_cast(x[i]); + vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next)); + } + } + } + } + + float val = 0.f; + for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i])); + + float final = reduce_block_into_lanes_max_op(s_vals, val); + + if (threadIdx.x == 0) { + if (!isfinite(final)) + *noop_gmem = + 1; // Blindly fire off a write. These will race but that's ok. + output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final)); + if (per_tensor) + output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * + max_chunks_per_tensor + + chunk_idx] = final; + } + } +}; + +__global__ void cleanup(float *output, float *output_per_tensor, float *ret, + float *ret_per_tensor, bool per_tensor, + int max_chunks_per_tensor) { + __shared__ float vals[512]; + + if (blockIdx.x == 0) { + float val = 0; + if (threadIdx.x < 320) val = output[threadIdx.x]; + + float final = reduce_block_into_lanes(vals, val); + + if (threadIdx.x == 0) *ret = sqrt(final); + } + + if (per_tensor) { + float *output_this_tensor = + output_per_tensor + blockIdx.x * max_chunks_per_tensor; + + float val = 0; + for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) + val += output_this_tensor[i]; + + float final = reduce_block_into_lanes(vals, val); + + if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final); + } +} + +__global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret, + float *ret_per_tensor, bool per_tensor, + int max_chunks_per_tensor, int norm_type, + float alpha, float beta) { + __shared__ float vals[512]; + + if (blockIdx.x == 0) { + float val = 0; + if (threadIdx.x < 320) val = output[threadIdx.x]; + + if (norm_type == 0) { + float final = reduce_block_into_lanes_max_op(vals, val); + if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final; + } else { + float final = reduce_block_into_lanes(vals, val); + if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final); + } + } + + if (per_tensor) { + float *output_this_tensor = + output_per_tensor + blockIdx.x * max_chunks_per_tensor; + + if (norm_type == 0) { + float val = 0; + for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) + val = fmaxf(fabsf(val), fabsf(output_this_tensor[i])); + + float final = reduce_block_into_lanes_max_op(vals, val); + + if (threadIdx.x == 0) + ret_per_tensor[blockIdx.x] = + alpha * ret_per_tensor[blockIdx.x] + beta * final; + } else { + float val = 0; + for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) + val += output_this_tensor[i]; + + float final = reduce_block_into_lanes(vals, val); + + if (threadIdx.x == 0) + ret_per_tensor[blockIdx.x] = sqrt(alpha * ret_per_tensor[blockIdx.x] * + ret_per_tensor[blockIdx.x] + + beta * final); + } + } +} + +std::tuple multi_tensor_l2norm_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python) { + bool per_tensor = + per_tensor_python.has_value() ? per_tensor_python.value() : false; + + auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); + auto output = at::zeros({320}, float_options); + + at::Tensor output_per_tensor; + at::Tensor ret_per_tensor; + + int ntensors = tensor_lists[0].size(); + int max_chunks_per_tensor = -1; + + if (per_tensor) { + for (int t = 0; t < ntensors; t++) { + int max_chunks_this_tensor = + (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + if (max_chunks_this_tensor > max_chunks_per_tensor) + max_chunks_per_tensor = max_chunks_this_tensor; + } + output_per_tensor = + at::zeros({ntensors * max_chunks_per_tensor}, float_options); + ret_per_tensor = at::empty({ntensors}, float_options); + } else { + ret_per_tensor = at::empty({0}, float_options); + } + + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", + multi_tensor_apply<1>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + L2NormFunctor(), output.DATA_PTR(), + per_tensor ? output_per_tensor.DATA_PTR() : nullptr, + per_tensor, max_chunks_per_tensor);) + + AT_CUDA_CHECK(cudaGetLastError()); + // AT_CUDA_CHECK(cudaDeviceSynchronize()); + + // This involves one more small kernel launches, but will be negligible end to + // end. I could get rid of these by hacking the functor + multi tensor harness + // with persistence logic, but keeping it simple for now + auto ret = at::empty({1}, output.options()); + const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); + auto stream = at::cuda::getCurrentCUDAStream(); + cleanup<<>>( + output.DATA_PTR(), + per_tensor ? output_per_tensor.DATA_PTR() : nullptr, + ret.DATA_PTR(), + per_tensor ? ret_per_tensor.DATA_PTR() : nullptr, per_tensor, + max_chunks_per_tensor); + + return std::tuple(ret, ret_per_tensor); +} + +// Compute and update grad norm +// Here use a per tensor norm, and blend new norm(n) and old norm(gn) by +// L-2: gn = sqrt(a * gn^2 + b * n^2) +// L-inf: gn = a * gn + b * n +void multi_tensor_norm_out_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor out, + const float alpha, const float beta, const int norm_type) { + auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); + TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), + "noop flag should be on the same device as tensors"); + // we don't need global thus uses empty here + auto output = at::empty({320}, float_options); + + at::Tensor output_per_tensor; + at::Tensor ret_per_tensor; + + int ntensors = tensor_lists[0].size(); + int max_chunks_per_tensor = -1; + + for (int t = 0; t < ntensors; t++) { + int max_chunks_this_tensor = + (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + if (max_chunks_this_tensor > max_chunks_per_tensor) + max_chunks_per_tensor = max_chunks_this_tensor; + } + + // Although it is single write then read, still need to be zero + // Since tailing element also participate cleanup + output_per_tensor = + at::zeros({ntensors * max_chunks_per_tensor}, float_options); + + if (norm_type == 0) { + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda", + multi_tensor_apply<1>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + MaxNormFunctor(), output.DATA_PTR(), + output_per_tensor.DATA_PTR(), true, max_chunks_per_tensor);) + } else { + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", + multi_tensor_apply<1>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + L2NormFunctor(), output.DATA_PTR(), + output_per_tensor.DATA_PTR(), true, max_chunks_per_tensor);) + } + AT_CUDA_CHECK(cudaGetLastError()); + + // AT_CUDA_CHECK(cudaDeviceSynchronize()); + + // This involves one more small kernel launches, but will be negligible end to + // end. I could get rid of these by hacking the functor + multi tensor harness + // with persistence logic, but keeping it simple for now + auto ret = at::empty({1}, output.options()); + + // Adding the following device guard since it happens sometimes that the + // tensors are on one device and the cuda stream is on another device which + // results in ILLEGAL MEM ACCESS error. + const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); + auto stream = at::cuda::getCurrentCUDAStream(); + cleanup_v2<<>>( + output.DATA_PTR(), output_per_tensor.DATA_PTR(), + ret.DATA_PTR(), out.DATA_PTR(), true, max_chunks_per_tensor, + norm_type, alpha, beta); + + return; +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu new file mode 100644 index 0000000000000000000000000000000000000000..54c4220190d80d6309e74c90da412d6ccda32c8f --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu @@ -0,0 +1,354 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_lamb.cu +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include "multi_tensor_apply.cuh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +template +__device__ __forceinline__ bool is_aligned(T *p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, + int src_offset) { + typedef + typename std::aligned_storage::type LT; + ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; +} + +typedef enum { + MOMENT_MODE_0 = 0, // L2 regularization mode + MOMENT_MODE_1 = 1 // Decoupled weight decay mode +} adamMode_t; + +std::tuple multi_tensor_l2norm_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python); + +using MATH_T = float; + +template +struct LAMBStage1Functor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, + const float beta1, const float beta2, const float beta3, + const float beta1_correction, const float beta2_correction, + const float epsilon, adamMode_t mode, const float decay, + const float *global_grad_norm, const float max_global_grad_norm) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + float clipped_global_grad_norm = + (*global_grad_norm) > max_global_grad_norm + ? (*global_grad_norm) / max_global_grad_norm + : 1.0f; + + T *g = (T *)tl.addresses[0][tensor_loc]; + g += chunk_idx * chunk_size; + + T *p = (T *)tl.addresses[1][tensor_loc]; + p += chunk_idx * chunk_size; + + T *m = (T *)tl.addresses[2][tensor_loc]; + m += chunk_idx * chunk_size; + + T *v = (T *)tl.addresses[3][tensor_loc]; + v += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(g) && + is_aligned(p) && is_aligned(m) && is_aligned(v)) { + T l_g[ILP]; + T l_p[ILP]; + T l_m[ILP]; + T l_v[ILP]; + for (int i_start = threadIdx.x; + i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(l_g, g, 0, i_start); + if (decay != 0) load_store(l_p, p, 0, i_start); + load_store(l_m, m, 0, i_start); + load_store(l_v, v, 0, i_start); + // unpack +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_g[ii] = l_g[ii]; + if (decay == 0) { + r_p[ii] = MATH_T(0); + } else { + r_p[ii] = l_p[ii]; + } + r_m[ii] = l_m[ii]; + r_v[ii] = l_v[ii]; + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == MOMENT_MODE_0) { + MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; + // L2 on scaled grad + scaled_grad = scaled_grad + decay * r_p[ii]; + r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + r_p[ii] = next_m_unbiased / denom; + } else { + MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; + r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + l_p[ii] = r_p[ii]; + l_m[ii] = r_m[ii]; + l_v[ii] = r_v[ii]; + } + // store + load_store(g, l_p, i_start, 0); + load_store(m, l_m, i_start, 0); + load_store(v, l_v, i_start, 0); + } + } else { + // see note in multi_tensor_scale_kernel.cu + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = g[i]; + // special ?optimization? for lamb stage 1 + if (decay == 0) { + r_p[ii] = MATH_T(0); + } else { + r_p[ii] = p[i]; + } + r_m[ii] = m[i]; + r_v[ii] = v[i]; + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == MOMENT_MODE_0) { + MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; + // L2 on scaled grad + scaled_grad = scaled_grad + decay * r_p[ii]; + r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + r_p[ii] = next_m_unbiased / denom; + } else { + MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; + r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + g[i] = r_p[ii]; + m[i] = r_m[ii]; + v[i] = r_v[ii]; + } + } + } + } + } +}; + +// Step 2 reads in 'update' value and per-tensor param_norm and update_norm. +// It computes new parameter value. +template +struct LAMBStage2Functor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl, + const float *per_tensor_param_norm, const float *per_tensor_update_norm, + const float learning_rate, const float decay, bool use_nvlamb) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int tensor_num = tl.start_tensor_this_launch + tensor_loc; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + MATH_T ratio = learning_rate; + // nvlamb: apply adaptive learning rate to all parameters + // otherwise, only apply to those with non-zero weight decay + if (use_nvlamb || (decay != 0.0)) { + float param_norm = per_tensor_param_norm[tensor_num]; + float update_norm = per_tensor_update_norm[tensor_num]; + ratio = (update_norm != 0.0f && param_norm != 0.0f) + ? learning_rate * (param_norm / update_norm) + : learning_rate; + } + + T *update = (T *)tl.addresses[0][tensor_loc]; + update += chunk_idx * chunk_size; + + T *p = (T *)tl.addresses[1][tensor_loc]; + p += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) && + is_aligned(update)) { + T r_p[ILP]; + T r_update[ILP]; + for (int i_start = threadIdx.x; + i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_p, p, 0, i_start); + load_store(r_update, update, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_p[ii] = static_cast(r_p[ii]) - + (ratio * static_cast(r_update[ii])); + } + load_store(p, r_p, i_start, 0); + } + } else { + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { + MATH_T r_p[ILP]; + MATH_T r_update[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_p[ii] = p[i]; + r_update[ii] = update[i]; + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_p[ii] = r_p[ii] - (ratio * r_update[ii]); + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p[i] = r_p[ii]; + } + } + } + } + } +}; + +void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, + const float beta2, const float epsilon, + const int step, const int bias_correction, + const float weight_decay, const int grad_averaging, + const int mode, at::Tensor global_grad_norm, + const float max_grad_norm, + at::optional use_nvlamb_python) { + using namespace at; + // Master weight and 32bit momentum(potentially changing) is not handled by + // this So we assume every tensor are all in the same type + + bool use_nvlamb = + use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false; + + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + + // Handle grad averaging mode + float beta3 = 1.0f; + if (grad_averaging == 1) beta3 = 1 - beta1; + + std::vector> grad_list(tensor_lists.begin(), + tensor_lists.begin() + 1); + std::vector> param_list(tensor_lists.begin() + 1, + tensor_lists.begin() + 2); + + // Compute per tensor param norm + auto param_norm_tuple = + multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true); + + // We now in-place modify grad to store update before compute its norm + // Generally this is not a issue since people modify grad in step() method all + // the time We can also grab list of empty tensor to avoid this, but I'd like + // to save space/cpu code + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + LAMBStage1Functor(), beta1, beta2, + beta3, // 1-beta1 or 1 depends on averaging mode + bias_correction1, bias_correction2, epsilon, + (adamMode_t)mode, weight_decay, + global_grad_norm.DATA_PTR(), max_grad_norm);) + + // Compute update norms + auto update_norm_tuple = + multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true); + + std::vector> grad_param_list( + tensor_lists.begin(), tensor_lists.begin() + 2); + + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", + multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list, + LAMBStage2Functor(), + std::get<1>(param_norm_tuple).DATA_PTR(), + std::get<1>(update_norm_tuple).DATA_PTR(), + lr, weight_decay, use_nvlamb);) + + AT_CUDA_CHECK(cudaGetLastError()); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..360485dcd02fbfc21a76a2bfa6dd6568b8909499 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu @@ -0,0 +1,125 @@ +#include +#include +#include +#include +// Another possibility: +// #include + +#include +// Stringstream is a big hammer, but I want to rely on operator<< for dtype. +#include + +#include "multi_tensor_apply.cuh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +template +__device__ __forceinline__ bool is_aligned(T *p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, + int src_offset) { + typedef + typename std::aligned_storage::type LT; + ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; +} + +template +struct ScaleFunctor { + __device__ __forceinline__ void operator()(int chunk_size, + volatile int *noop_gmem, + TensorListMetadata<2> &tl, + float scale) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + in_t *in = (in_t *)tl.addresses[0][tensor_loc]; + in += chunk_idx * chunk_size; + + out_t *out = (out_t *)tl.addresses[1][tensor_loc]; + out += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + bool finite = true; + in_t r_in[ILP]; + out_t r_out[ILP]; + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && + is_aligned(out)) { + for (int i_start = threadIdx.x; + i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_in, in, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_out[ii] = static_cast(r_in[ii]) * scale; + finite = finite && isfinite(r_in[ii]); + } + // store + load_store(out, r_out, i_start, 0); + } + } else { + // Non-divergent exit condition for __syncthreads, not necessary here + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_in[ii] = 0; + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) r_in[ii] = in[i]; + } + // note for clarification to future michael: + // From a pure memory dependency perspective, there's likely no point + // unrolling the write loop, since writes just fire off once their LDGs + // arrive. Put another way, the STGs are dependent on the LDGs, but not + // on each other. There is still compute ILP benefit from unrolling the + // loop though. +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_out[ii] = static_cast(r_in[ii]) * scale; + finite = finite && isfinite(r_in[ii]); + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) out[i] = r_out[ii]; + } + } + } + if (!finite) + *noop_gmem = + 1; // Blindly fire off a write. These will race but that's ok. + } +}; + +void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + float scale) { + using namespace at; + // The output (downscaled) type is always float. + // If build times suffer, think about where to put this dispatch, + // and what logic should be moved out of multi_tensor_apply. + + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda", + DISPATCH_FLOAT_AND_HALF( + tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda", + multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + ScaleFunctor(), + scale);)) + AT_CUDA_CHECK(cudaGetLastError()); + + // AT_CUDA_CHECK(cudaDeviceSynchronize()); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..35f2c9b4ed15eab94b1456ce436694180d706a45 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu @@ -0,0 +1,167 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu +#include +#include +#include +#include +#include +#include + +#include "compat.h" +#include "multi_tensor_apply.cuh" + +#define BLOCK_SIZE 512 +#define ILP 4 + +/** + * Perform fused SGD on multiple buffers + * N: number of tensors + * tl[0] : gradients + * tl[1] : weights + * tl[2] : momentum buffers + * tl[3] : fp16 weights (if appropriate) + * wd : weight_decay (scalar) + * momentum : momentum (scalar) + * dampening : momentum dampening (scalar) + * lr : learning rate (scalar) + * nesterov : enable nesterov (bool) + * first run : necessary for proper momentum handling & init + * wd_after_momentum : apply weight decay _after_ momentum instead of before + **/ +template +struct SGDFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl, + float wd, float momentum, float dampening, float lr, bool nesterov, + bool first_run, bool wd_after_momentum, float scale) { + // Early exit if we don't need to do anything + if (*noop_gmem) return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc]; + grad_in += chunk_idx * chunk_size; + + T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc]; + weight_in += chunk_idx * chunk_size; + + T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc]; + mom_in += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // Non-divergent exit condition for the __syncthreads + float incoming_grads[ILP]; + float incoming_weights[ILP]; + float incoming_moms[ILP]; + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + incoming_grads[ii] = 0; + incoming_weights[ii] = 0; + incoming_moms[ii] = 0; + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + incoming_grads[ii] = static_cast(grad_in[i]) * scale; + incoming_weights[ii] = static_cast(weight_in[i]); + incoming_moms[ii] = static_cast(mom_in[i]); + } + } + +// note for clarification to future michael: +// From a pure memory dependency perspective, there's likely no point unrolling +// the write loop, since writes just fire off once their LDGs arrive. +// Put another way, the STGs are dependent on the LDGs, but not on each other. +// There is still compute ILP benefit from unrolling the loop though. +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + // apply weight decay before momentum if necessary + if (wd != 0.f && !wd_after_momentum) + incoming_grads[ii] += wd * incoming_weights[ii]; + + if (momentum != 0.f) { + if (!first_run) + incoming_moms[ii] = incoming_moms[ii] * momentum + + (1.f - dampening) * incoming_grads[ii]; + else // initialize momentums to current incoming grads + incoming_moms[ii] = incoming_grads[ii]; + + if (nesterov) + incoming_grads[ii] += momentum * incoming_moms[ii]; + else + incoming_grads[ii] = incoming_moms[ii]; + } + + // Apply WD after momentum if desired + if (wd != 0.f && wd_after_momentum) + incoming_grads[ii] += wd * incoming_weights[ii]; + + // adjust the weight and write out + weight_in[i] += (-lr * incoming_grads[ii]); + + // also write out the new momentum + if (momentum != 0.f) mom_in[i] = incoming_moms[ii]; + } + } + } + } +}; + +void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + float wd, float momentum, float dampening, float lr, + bool nesterov, bool first_run, + bool wd_after_momentum, float scale) { + auto num_tensors = tensor_lists.size(); + auto grad_type = tensor_lists[0][0].scalar_type(); + auto weight_type = tensor_lists[1][0].scalar_type(); + + TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), + "expected noop flag to be on the same device as tensors"); + + // We have 3 possibilities to handle here, in terms of + // grad_type, param_type, momentum_type + // 1. fp16, fp16, fp16 + // 2. fp32, fp32, fp32 + // 3. fp16, fp32, fp32 + // It's easier to hardcode these possibilities than to use + // switches etc. to handle the cross-product of cases where + // we don't want the majority of them. + + // Case 1. fp16, fp16, fp16, No + if (grad_type == at::ScalarType::Half && + weight_type == at::ScalarType::Half && num_tensors == 3) { + multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + SGDFunctor(), wd, momentum, + dampening, lr, nesterov, first_run, wd_after_momentum, + scale); + } + // Case 2. fp32, fp32, fp32 + else if (grad_type == at::ScalarType::Float && + weight_type == at::ScalarType::Float && num_tensors == 3) { + multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + SGDFunctor(), wd, momentum, dampening, + lr, nesterov, first_run, wd_after_momentum, scale); + } + // Case 3. fp16, fp32, fp32 + else if (grad_type == at::ScalarType::Half && + weight_type == at::ScalarType::Float && num_tensors == 3) { + multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + SGDFunctor(), wd, momentum, + dampening, lr, nesterov, first_run, wd_after_momentum, + scale); + } else { + AT_ERROR( + "multi_tensor_sgd only supports some combinations of gradient & weight " + "types. Given: ", + "gradient: ", grad_type, ", weight: ", weight_type, + ", num_lists: ", num_tensors); + } + + AT_CUDA_CHECK(cudaGetLastError()); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..166c698f617bf37b73a6c290dafc197f05b25980 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp @@ -0,0 +1,405 @@ +#include "multihead_attention_1d.h" + +#include +#include +#include + +#if TORCH_VERSION_MINOR >= 13 +#include +#else +#include +#endif +#include + +#include "context.h" +#include "kernels.h" + +template +MultiHeadAttention::MultiHeadAttention(int layer_id, int max_batch_tokens, + int max_seq_len, int hidden_size, + int num_heads, + float attn_prob_dropout_ratio, + float hidden_output_dropout_ratio, + bool pre_or_postLayerNorm) + : _layer_id(layer_id), + _max_batch_tokens(max_batch_tokens), + _max_seq_len(max_seq_len), + _hidden_size(hidden_size), + _heads(num_heads), + _training(true), + _pre_or_postLayerNorm(pre_or_postLayerNorm), + _qkv_linear( + typename FeedForward::Config(3 * hidden_size, hidden_size)), + _attn_out_linear( + typename FeedForward::Config(hidden_size, hidden_size)), + _attn_ln(typename Normalize_Layer::Config(hidden_size, false), + _max_batch_tokens), + _softmax(typename Softmax::Config(num_heads)), + _attn_prob_dropout(typename Dropout::Config(attn_prob_dropout_ratio), + _max_batch_tokens * _heads * _max_seq_len), + _attn_dropout(typename Dropout::Config(hidden_output_dropout_ratio), + _max_batch_tokens * _hidden_size), + _attn_scores(typename StridedBatchGemm::Config( + (T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T, + CUBLAS_OP_N)), + _attn_context(typename StridedBatchGemm::Config( + T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) { + assert(_hidden_size % _heads == 0); +} + +template +MultiHeadAttention::~MultiHeadAttention() { + free_mem_buffer(); +} + +template +void MultiHeadAttention::attn_layer_fw(const T *input_ptr, + const T *input_mask_ptr, + T *output_ptr, T *buffer) { + T *q_tf_ptr = _qkv_ptr; + T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size; + T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size; + + if (_pre_or_postLayerNorm) { + _attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr, + _batch_tokens, _stream); + } + const T *gemmQKV_inp_ptr = + _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; + _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); + _qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer, + _cublasHandle); + + launch_bias_add_transform_20314(q_tf_ptr, buffer, _attn_qkvb_ptr, + _batch_size, _seq_len, 3, _heads / pg_size, + _hidden_size / _heads, _stream); + + // attention scores, q*k + _attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr, + _cublasHandle); + + // Softmax + Mask + _softmax.reset_size(_heads / pg_size); + _softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len, + _seq_len, _stream, true); + + // attn prob dropout. + _attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr, + _batch_heads * _seq_len * _seq_len, _stream); + + // attention context, score * v + _attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr, + _cublasHandle); + + // [b, nh, s, ad] -> [b, s, nh, ad] + launch_transform4d_0213(_attn_o_inp_ptr, buffer, _batch_size, _seq_len, + _hidden_size / pg_size, _heads / pg_size, 1, + _stream); + + _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); + _attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr, + output_ptr, _cublasHandle); + + // allreduce + if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { + } else { + auto data_type = torch::kFloat; + if (typeid(T) != typeid(float)) { + data_type = torch::kHalf; + } + auto output_tensor = torch::from_blob( + output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)}, + torch::TensorOptions(torch::kCUDA).dtype(data_type)); + std::vector allreduce_tensors = {output_tensor}; + auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); + work->wait(); + } + + _attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr, + _attn_ob_ptr, _batch_tokens, _hidden_size, + _stream); + if (!_pre_or_postLayerNorm) { + // in-place ln since ln-input will not be used in post-ln mode + _attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, + _batch_tokens, _stream); + } +} + +template +void MultiHeadAttention::Forward(const T *input_ptr, const T *input_mask_ptr, + T *out_ptr) { + _stream = Context::Instance().get_stream(); + _cublasHandle = Context::Instance().get_cublashandle(); + T *attn_buffer = _shared_mem_ptr; // 3 * _batch_dim + + attn_layer_fw(input_ptr, input_mask_ptr, out_ptr, attn_buffer); +} + +template +void MultiHeadAttention::attn_layer_bw(const T *input_ptr, + const T *input_mask_ptr, + const T *output_ptr, + const T *grad_output_ptr, + T *grad_input_ptr, T *buffer) { + cudaStream_t streams[2] = {_stream, _stream}; + + const T *q_tf_ptr = _qkv_ptr; + const T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size; + const T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size; + // batch_dim = batch_size * seq_len * hidden_size + // buffer size: batch_dim * 3 + max(batch_dim * 3, + // batch_size * head_num * seq_len * seq_len) + T *grad_residual_ptr = buffer; + buffer += _batch_dim; + + T *grad_input_buf_ptr = buffer; // batch_dim + T *grad_qkv_5d_ptr = buffer; // batch_dim * 3 + buffer += 3 * _batch_dim / pg_size; + + T *grad_qkv_4d_ptr = buffer; // batch_dim * 3 + T *grad_softmax_ptr = buffer; // batch_size * head_num * seq_len * seq_len + // buffer += max(3 * _batch_dim, + // batch_size * head_num * seq_len * seq_len); + + if (_pre_or_postLayerNorm) { + _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, + grad_output_ptr, _batch_tokens, + _hidden_size, _stream); + } else { + _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr, + grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr, + _attn_nb_ptr, _batch_tokens, streams); + _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, + grad_residual_ptr, _batch_tokens, + _hidden_size, _stream); + } + + // bw of output project + _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); + _attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr, + _attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr, + _cublasHandle, _stream, grad_input_buf_ptr, nullptr, + false); + launch_transform_0213(grad_input_ptr, grad_input_buf_ptr, _batch_size, + _seq_len, _hidden_size / pg_size, _heads / pg_size, + _stream); + + // bw of score * v + _attn_context.Backward( + _batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle, + grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr); + + _attn_prob_dropout.d_dropout(grad_softmax_ptr, + _batch_heads * _seq_len * _seq_len, _stream); + + _softmax.reset_size(_heads / pg_size); + _softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len, + _seq_len, _stream); + + // bw of q * k + _attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr, + _cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size, + grad_qkv_5d_ptr); + + // [3, b, nh, s, ad] -> [b, s, 3, h] + launch_transform4d_0213(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size, + _seq_len, _hidden_size / pg_size, _heads / pg_size, + 3, _stream); + + const T *gemmQKV_inp_ptr = + _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; + _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); + _qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr, + _attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr, + _cublasHandle, _stream, grad_input_buf_ptr, nullptr, + true); + + // allreduce + if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { + } else { + auto data_type = torch::kFloat; + if (typeid(T) != typeid(float)) { + data_type = torch::kHalf; + } + auto grad_input_tensor = + torch::from_blob(grad_input_buf_ptr, + {int(_batch_size), int(_seq_len), int(_hidden_size)}, + torch::TensorOptions(torch::kCUDA).dtype(data_type)); + std::vector allreduce_tensors = {grad_input_tensor}; + auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); + work->wait(); + } + + if (_pre_or_postLayerNorm) { + _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr, + grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr, + _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams); + } else { + // FIXME later + launch_fused_add2(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr, + _batch_size, _seq_len, _hidden_size, _stream); + } +} + +template +void MultiHeadAttention::Backward(const T *grad_output_ptr, + const T *input_ptr, const T *output_ptr, + const T *input_mask_ptr, + T *grad_input_ptr) { + _stream = Context::Instance().get_stream(); + _cublasHandle = Context::Instance().get_cublashandle(); + T *buffer = _shared_mem_ptr; + + /* + buffer size needed by attn bw: + 4 * _batch_dim + max(3 * _batch_dim, + _batch_size * _head_num * _seq_len * _seq_len); + */ + attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr, + grad_input_ptr, buffer); +} + +template +void MultiHeadAttention::SetTrainingMode(bool training) { + // Dropout will be skipped when not in training model. + _attn_prob_dropout.SetTrainingMode(training); + _attn_dropout.SetTrainingMode(training); +} + +template +T *MultiHeadAttention::_shared_mem_ptr = nullptr; + +template class MultiHeadAttention; +template class MultiHeadAttention<__half>; + +// x is torch::Tensor +#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +static std::unordered_map> s_multihead_attention; + +template +int create_multihead_attention(int layer_id, int max_batch_tokens, + int max_seq_len, int hidden_dim, int num_heads, + float attn_prob_dropout_ratio, + float hidden_dropout_ratio, + bool pre_or_postLayerNorm, + c10::intrusive_ptr pg_) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + Context::Instance().set_stream(stream); + auto layer = std::make_shared>( + layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads, + attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm); + + layer->SetPG(pg_); + + s_multihead_attention[layer_id] = layer; + + std::string dtype = (std::is_same::value) ? "half" : "float"; + + return 0; +} + +template +std::vector multihead_attention_fw( + int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask, + const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias, + const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias, + const torch::Tensor &norm_weight, const torch::Tensor &norm_bias, + bool training_mode, bool prelayernorm) { + CHECK_INPUT(input); + CHECK_INPUT(input_mask); + + const T *input_ptr = (const T *)input.data_ptr(); + const T *input_mask_ptr = (const T *)input_mask.data_ptr(); + + auto output = torch::empty_like(input); + T *out_ptr = (T *)output.data_ptr(); + + std::shared_ptr> layer = + std::static_pointer_cast>( + s_multihead_attention[layer_id]); + layer->set_cur_batch_shape(input.size(0), input.size(1)); + layer->SetTrainingMode(training_mode); + + layer->_attn_qkvw_ptr = (const T *)in_proj_weight.data_ptr(); + layer->_attn_qkvb_ptr = (const T *)in_proj_bias.data_ptr(); + layer->_attn_ow_ptr = (const T *)out_proj_weight.data_ptr(); + layer->_attn_ob_ptr = (const T *)out_proj_bias.data_ptr(); + layer->_attn_nw_ptr = (const T *)norm_weight.data_ptr(); + layer->_attn_nb_ptr = (const T *)norm_bias.data_ptr(); + + layer->Forward(input_ptr, input_mask_ptr, out_ptr); + + return {output}; +} + +template +std::vector multihead_attention_bw( + int layer_id, const torch::Tensor &grad_dec_output, + const torch::Tensor &output, const torch::Tensor &input, + const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight, + const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight, + const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight, + const torch::Tensor &norm_bias) { + auto g_output = grad_dec_output.contiguous(); + CHECK_INPUT(g_output); + CHECK_INPUT(output); + CHECK_INPUT(input); + CHECK_INPUT(input_mask); + + auto grad_input = torch::empty_like(input); + auto grad_in_proj_weight = torch::empty_like(in_proj_weight); + auto grad_in_proj_bias = torch::empty_like(in_proj_bias); + auto grad_out_proj_weight = torch::empty_like(out_proj_weight); + auto grad_out_proj_bias = torch::empty_like(out_proj_bias); + auto grad_norm_weight = torch::empty_like(norm_weight); + auto grad_norm_bias = torch::empty_like(norm_bias); + + // inputs. + const T *grad_dec_output_ptr = (const T *)g_output.data_ptr(); + const T *input_ptr = (const T *)input.data_ptr(); + const T *output_ptr = (const T *)output.data_ptr(); + const T *input_mask_ptr = (const T *)input_mask.data_ptr(); + + // outputs. + T *grad_input_ptr = (T *)grad_input.data_ptr(); + + std::shared_ptr> layer = + std::static_pointer_cast>( + s_multihead_attention[layer_id]); + layer->set_cur_batch_shape(g_output.size(0), g_output.size(1)); + + layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr(); + layer->_grad_attn_qkvb_ptr = (T *)grad_in_proj_bias.data_ptr(); + layer->_grad_attn_ow_ptr = (T *)grad_out_proj_weight.data_ptr(); + layer->_grad_attn_ob_ptr = (T *)grad_out_proj_bias.data_ptr(); + layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr(); + layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr(); + + layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr, + grad_input_ptr); + + return {grad_input, grad_in_proj_weight, grad_in_proj_bias, + grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight, + grad_norm_bias}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("multihead_attention_fw_fp32", &multihead_attention_fw, + "Multi-head Attention forward with fp32 (CUDA)"); + m.def("multihead_attention_fw_fp16", &multihead_attention_fw<__half>, + "Multi-head Attention forward with fp16 (CUDA)"); + m.def("multihead_attention_bw_fp32", &multihead_attention_bw, + "Multi-head Attention backward with fp32 (CUDA)"); + m.def("multihead_attention_bw_fp16", &multihead_attention_bw<__half>, + "Multi-head Attention backward with fp16 (CUDA)"); + m.def("create_multihead_attention_fp32", &create_multihead_attention, + "Create Multi-head Attention with fp32 (CUDA)"); + m.def("create_multihead_attention_fp16", &create_multihead_attention<__half>, + "Create Multi-head Attention with fp16 (CUDA)"); +} diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h new file mode 100644 index 0000000000000000000000000000000000000000..061abb89f0327ee73c214898ee6017f137bb3a69 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h @@ -0,0 +1,170 @@ +#pragma once + +#include +#include +#include +#include +#include + +#if TORCH_VERSION_MINOR >= 13 +#include +#else +#include +#endif + +#include +#include + +#ifdef COLOSSAL_HIP +#include "hip_util.h" +#else +#include "cuda_util.h" +#endif +#include "dropout.h" +#include "feed_forward.h" +#include "normalize_layer.h" +#include "softmax.h" +#include "strided_batch_gemm.h" + +template +class MultiHeadAttention { + public: + MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, + int hidden_size, int num_heads, float attn_dropout_ratio, + float hidden_output_dropout_ratio, + bool pre_or_postLayerNorm); + + virtual ~MultiHeadAttention(); + + void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr); + + void Backward(const T *grad_output_ptr, const T *input_ptr, + const T *output_ptr, const T *input_mask_ptr, + T *grad_input_ptr); + + void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, + T *buffer); + + void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, + const T *output_ptr, const T *grad_output_ptr, + T *grad_input_attn_layer_bwptr, T *buffer); + + void set_cur_batch_shape(int batch_size, int seq_len) { + _batch_size = batch_size; + _seq_len = seq_len; + _batch_tokens = batch_size * seq_len; + _batch_heads = batch_size * _heads / pg_size; + _batch_dim = _batch_tokens * _hidden_size; + _attn_scores.SetConfig(_seq_len, _seq_len, _hidden_size / _heads); + _attn_context.SetConfig(_hidden_size / _heads, _seq_len, _seq_len); + } + + void SetTrainingMode(bool training); + inline bool IsTrainingMode() const { return _training; } + + void SetPG(c10::intrusive_ptr pg_) { + pg = pg_; + pg_size = 1; + if (pg != c10::detail::UniqueVoidPtr()) { + pg_size = pg->getSize(); + } + allocate_mem_buffer(); + } + + // weights ptr + const T *_attn_qkvw_ptr; + const T *_attn_qkvb_ptr; + const T *_attn_ow_ptr; + const T *_attn_ob_ptr; + const T *_attn_nw_ptr; + const T *_attn_nb_ptr; + + // grads ptr + T *_grad_attn_qkvw_ptr; + T *_grad_attn_qkvb_ptr; + T *_grad_attn_ow_ptr; + T *_grad_attn_ob_ptr; + T *_grad_attn_nw_ptr; + T *_grad_attn_nb_ptr; + + private: + void allocate_mem_buffer() { + // allocate local gpu memory + if (_pre_or_postLayerNorm) { + _gemmQKV_inp_ptr = cuda_malloc(_max_batch_tokens * _hidden_size); + } else { + _gemmQKV_inp_ptr = nullptr; + } + + _qkv_ptr = cuda_malloc(_max_batch_tokens * _hidden_size * 3); + _soft_out_ptr = + cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); + _ctx_bufB_ptr = + cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); + _attn_o_inp_ptr = cuda_malloc(_max_batch_tokens * _hidden_size); + + // buffer size needed by attn bw + size_t smem_size = + 4 * _max_batch_tokens * _hidden_size / pg_size + + std::max(3 * _max_batch_tokens * _hidden_size / pg_size, + _max_batch_tokens * _heads / pg_size * _max_seq_len); + + if (!_shared_mem_ptr) { + cuda_free(_shared_mem_ptr); + _shared_mem_ptr = cuda_malloc(smem_size); + } + } + + void free_mem_buffer() { + // free local gpu memory + cuda_free(_gemmQKV_inp_ptr); + cuda_free(_qkv_ptr); + cuda_free(_soft_out_ptr); + cuda_free(_ctx_bufB_ptr); + cuda_free(_attn_o_inp_ptr); + + // free shared gpu memory between layers + cuda_free(_shared_mem_ptr); + _shared_mem_ptr = nullptr; + } + + // const parameter between batch + const size_t _layer_id; + const size_t _hidden_size; + const size_t _heads; + const size_t _max_batch_tokens; + const size_t _max_seq_len; + const bool _pre_or_postLayerNorm; + // dynamic parameter between batch + size_t _batch_size; + size_t _seq_len; + size_t _batch_tokens; + size_t _batch_heads; + size_t _batch_dim; + bool _training; + + cublasHandle_t _cublasHandle; + cudaStream_t _stream; + + // layers + FeedForward _qkv_linear; + FeedForward _attn_out_linear; + Normalize_Layer _attn_ln; + Softmax _softmax; + Dropout _attn_prob_dropout; + Dropout _attn_dropout; + StridedBatchGemm _attn_scores; + StridedBatchGemm _attn_context; + + // local GPU memory + T *_gemmQKV_inp_ptr; + T *_qkv_ptr; + T *_soft_out_ptr; + T *_ctx_bufB_ptr; + T *_attn_o_inp_ptr; + // shared GPU memory between layer + static T *_shared_mem_ptr; + + c10::intrusive_ptr pg; + int pg_size; +}; diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4ae3c853ca5e844272ca4fdb907c8c95a7f2b787 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp @@ -0,0 +1,84 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +int get_batch_per_block_cuda( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads); + +torch::Tensor fwd( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) { + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); + + return fwd_cuda(input, mask, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +int get_batch_per_block( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) { + return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); +} + +} // end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + + m.def("backward", + &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); + + m.def("get_batch_per_block", + &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, + "Return Batch per block size." + ); +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..1583030b8235acfb3a3af1a86fa938901ae52bbb --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h @@ -0,0 +1,492 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Explicit masking + */ +template +__global__ void scaled_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const acc_t scale, + int micro_batch_size, + int element_count, + int pad_batches) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i*element_count+it*WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + +template +__global__ void scaled_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); + } + } + } +} +} // end of anonymous namespace + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; +} + +template +void dispatch_scaled_masked_softmax_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads, + int pad_batches) +{ + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); + dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) +{ + TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count/batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + default: + break; + } + } +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..0d8678e4bcd58be56ae7faafbc7f403b10a12e51 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu @@ -0,0 +1,91 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#include +#include +#include +#include +#ifndef COLOSSAL_HIP +#include +#endif +#include +#include + +#include "scaled_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, + int attn_heads) { + return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); +} + +torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, + float scale_factor) { + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, + // seq_len] + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 2048); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); + TORCH_INTERNAL_ASSERT(mask.size(1) == 1); + TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); + TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = torch::empty( + {batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* mask_ptr = static_cast(mask.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), "dispatch_scaled_masked_softmax_forward", + dispatch_scaled_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), scale_factor, + query_seq_len, key_seq_len, batches, attn_heads, pad_batches);); + return softmax_results; +} + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, + // seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + // Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, query_seq_len, key_seq_len, batches, attn_heads);); + + // backward pass is completely in-place + return output_grads; +} +} // namespace scaled_masked_softmax +} // namespace fused_softmax +} // namespace multihead_attn diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cbbc3706497a69ea142e2898acc321c383fd1939 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp @@ -0,0 +1,54 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#include +#include + +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor); + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { + AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return fwd_cuda(input, scale_factor); +} + +torch::Tensor bwd(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, float scale_factor) { + AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +} // end namespace scaled_upper_triang_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("backward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..3af487f9de0ffdc22faaca142cbc2ff86b68d03e --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h @@ -0,0 +1,500 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +template +__device__ __inline__ void copy_zero_vector(Datatype *dst); + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Implicit time (diagonal masking) + */ +template +__global__ void scaled_upper_triang_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it+element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0; + } + } + copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); + } else { + break; + } + } + } +} + +template +__global__ void scaled_upper_triang_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); + } + } + } +} + +} // end of anonymous namespace + +template +void dispatch_scaled_upper_triang_masked_softmax_forward( + output_t *dst, + const input_t *src, + const input_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_upper_triang_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..09cf612d7e699f24758850f7f02d5290ac034373 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu @@ -0,0 +1,77 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#include +#include +#include +#include +#ifndef COLOSSAL_HIP +#include +#endif +#include +#include + +#include "scaled_upper_triang_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { + // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = input.size(0); + const int seq_len = input.size(1); + TORCH_INTERNAL_ASSERT(seq_len <= 2048); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({attn_batches, seq_len, seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_forward", + dispatch_scaled_upper_triang_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), scale_factor, seq_len, + seq_len, attn_batches);); + return softmax_results; +} + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + // output grads is a 3d tensor with dimensions [attn_batches, seq_len, + // seq_len] + const int attn_batches = output_grads.size(0); + const int seq_len = output_grads.size(1); + TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + // Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_backward", + dispatch_scaled_upper_triang_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, seq_len, seq_len, attn_batches);); + + // backward pass is completely in-place + return output_grads; +} +} // namespace scaled_upper_triang_masked_softmax +} // namespace fused_softmax +} // namespace multihead_attn diff --git a/colossalai/kernel/cuda_native/csrc/type_shim.h b/colossalai/kernel/cuda_native/csrc/type_shim.h new file mode 100644 index 0000000000000000000000000000000000000000..41c96856c67866e37185cbf8d3f0c1808f4ba358 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/type_shim.h @@ -0,0 +1,310 @@ +#include +#include "compat.h" + + +#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ + switch(TYPEIN) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_in = float; \ + switch(TYPEOUT) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } + +// Forward/backward compatiblity hack around +// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288 +// pending more future-proof guidance from upstream. +// struct TypeShim +// { +// const at::Type& payload; +// TypeShim(const at::Type& type) : payload(type) {} +// // Enable trivial conversion to a const at::Type& for pre-3aeb78 +// operator const at::Type&(){ return payload; }; +// // Enable dispatch switch statements to take *this directly for post-3aeb78 +// //operator at::ScalarType(){ return payload.; }; +// }; + +#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Byte: \ + { \ + using scalar_t_##LEVEL = uint8_t; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Double: \ + { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Double: \ + { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \ + if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) \ + { \ + using g_scalar_t_##LEVEL = float; \ + using p_scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + } \ + else if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Half) \ + { \ + using g_scalar_t_##LEVEL = float; \ + using p_scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + } \ + else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Float) \ + { \ + using g_scalar_t_##LEVEL = at::Half; \ + using p_scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + } \ + else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) \ + { \ + using g_scalar_t_##LEVEL = at::Half; \ + using p_scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + } \ + else \ + { \ + AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), "'"); \ + } \ + +template +__device__ __forceinline__ T reduce_block_into_lanes(T *x, + T val, + int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) + { + x[tid] = val; + __syncthreads(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) + { + if (tid < i) + x[tid] = x[tid] + x[tid + i]; + __syncthreads(); + } + + T final; + + if (tid < 32) + { + if (blockSize >= 64) + final = x[tid] + x[tid + 32]; + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) +#ifdef COLOSSAL_HIP + final = final + __shfl_down(final, i); +#else + final = final + __shfl_down_sync(0xffffffff, final, i); +#endif + } + + if (share_result) + { + if (tid < lanes) + x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + __syncthreads(); + } + + return final; +} + +template +__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x, + T val, + int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) + { + x[tid] = val; + __syncthreads(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) + { + if (tid < i) + x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i])); + __syncthreads(); + } + + T final; + + if (tid < 32) + { + if (blockSize >= 64) + final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) +#ifdef COLOSSAL_HIP + final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i))); +#else + final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); +#endif + } + + if (share_result) + { + if (tid < lanes) + x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + __syncthreads(); + } + + return final; +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..7bd646d3935f2eb6a2642a2027cff6bf98bf7a07 --- /dev/null +++ b/colossalai/kernel/cuda_native/flash_attention.py @@ -0,0 +1,525 @@ +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton) +""" + +import math +import os +import subprocess + +import torch + + +def triton_cuda_check(): + cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda") + cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip() + cuda_version = cuda_version.split('release ')[1] + cuda_version = cuda_version.split(',')[0] + cuda_version = cuda_version.split('.') + if len(cuda_version) == 2 and \ + (int(cuda_version[0]) == 11 and int(cuda_version[1]) >= 4) or \ + int(cuda_version[0]) > 11: + return True + return False + + +try: + import triton + import triton.language as tl + if triton_cuda_check(): + HAS_TRITON = True + else: + print("triton requires cuda >= 11.4") + HAS_TRITON = False +except ImportError: + print('please install triton from https://github.com/openai/triton') + HAS_TRITON = False +try: + from flash_attn.flash_attention import FlashAttention + from flash_attn.flash_attn_interface import ( + flash_attn_unpadded_func, + flash_attn_unpadded_kvpacked_func, + flash_attn_unpadded_qkvpacked_func, + ) + HAS_FLASH_ATTN = True +except ImportError: + HAS_FLASH_ATTN = False + print('please install flash_attn from https://github.com/HazyResearch/flash-attention') + +try: + from xformers.ops.fmha import memory_efficient_attention + HAS_MEM_EFF_ATTN = True +except ImportError: + HAS_MEM_EFF_ATTN = False + print('please install xformers from https://github.com/facebookresearch/xformers') + +if HAS_TRITON: + + @triton.jit + def _fwd_kernel( + Q, + K, + V, + sm_scale, + TMP, + L, + M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + Z, + H, + N_CTX, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + # initialize pointer to m and l + t_ptrs = TMP + off_hz * N_CTX + offs_m + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + # loop over k, v and update accumulator + for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + start_n * stride_kn) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + qk *= sm_scale + qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf")) + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + start_n * stride_vk) + p = p.to(tl.float16) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + l_ptrs = L + off_hz * N_CTX + offs_m + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(l_ptrs, l_i) + tl.store(m_ptrs, m_i) + # initialize pointers to output + offs_n = tl.arange(0, BLOCK_DMODEL) + off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + + @triton.jit + def _bwd_preprocess( + Out, + DO, + L, + NewDO, + Delta, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, + ): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + denom = tl.load(L + off_m).to(tl.float32) + # compute + do = do / denom[:, None] + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + @triton.jit + def _bwd_kernel( + Q, + K, + V, + sm_scale, + Out, + DO, + DQ, + DK, + DV, + L, + M, + D, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + Z, + H, + N_CTX, + num_block, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_qz + off_h * stride_qh + V += off_z * stride_qz + off_h * stride_qh + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DK += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_qz + off_h * stride_qh + for start_n in range(0, num_block): + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + qk = tl.dot(q, k, trans_b=True) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + m = tl.load(m_ptrs + offs_m_curr) + p = tl.exp(qk * sm_scale - m[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(p.to(tl.float16), do, trans_a=True) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, v, trans_b=True) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(ds.to(tl.float16), q, trans_a=True) + # # compute dq + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds.to(tl.float16), k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + # # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + + class _TritonFlashAttention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, sm_scale): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) + tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + tmp, + L, + m, + o, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + BLOCK_DMODEL=Lk, + num_warps=num_warps, + num_stages=1, + ) + ctx.save_for_backward(q, k, v, o, L, m) + ctx.BLOCK = BLOCK + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, l, m = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + do_scaled = torch.empty_like(do) + delta = torch.empty_like(l) + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)]( + o, + do, + l, + do_scaled, + delta, + BLOCK_M=ctx.BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, + ) + + # NOTE: kernel currently buggy for other values of `num_warps` + num_warps = 8 + _bwd_kernel[(ctx.grid[1],)]( + q, + k, + v, + ctx.sm_scale, + o, + do_scaled, + dq, + dk, + dv, + l, + m, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + ctx.grid[0], + BLOCK_M=ctx.BLOCK, + BLOCK_N=ctx.BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + num_warps=num_warps, + num_stages=1, + ) + return dq, dk, dv, None + + def triton_flash_attention(q, k, v, sm_scale): + """ + Arguments: + q: (batch, nheads, seq, headdim) + k: (batch, nheads, seq, headdim) + v: (batch, nheads, seq, headdim) + sm_scale: float. The scaling of QK^T before applying softmax. + Return: + out: (batch, nheads, seq, headdim) + """ + if HAS_TRITON: + return _TritonFlashAttention.apply(q, k, v, sm_scale) + else: + raise RuntimeError("Triton kernel requires CUDA 11.4+!") + + +if HAS_FLASH_ATTN: + + from einops import rearrange + + class MaskedFlashAttention(torch.nn.Module): + + def __init__(self, num_attention_heads: int, attention_head_size: int, attention_dropout: float) -> None: + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_size = attention_head_size + self.attention_func = FlashAttention(softmax_scale=math.sqrt(attention_head_size), + attention_dropout=attention_dropout) + + def forward(self, query_key_value: torch.Tensor, attention_mask: torch.Tensor, causal=False): + if attention_mask.dtype is not torch.bool: + attention_mask = attention_mask.bool() + qkv = rearrange(query_key_value, 'b s (three h d) -> b s three h d', three=3, h=self.num_attention_heads) + context, _ = self.attention_func(qkv, key_padding_mask=attention_mask, causal=causal) + context = rearrange(context, 'b s h d -> b s (h d)') + return context + + def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False): + """ + Arguments: + qkv: (batch * seqlen, 3, nheads, headdim) + batch_size: int. + seq_len: int. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + dropout_p: float. + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + out: (total, nheads, headdim). + """ + max_s = seq_len + cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device) + out = flash_attn_unpadded_qkvpacked_func(qkv, + cu_seqlens, + max_s, + dropout_p, + softmax_scale=sm_scale, + causal=causal) + return out + + def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False): + """ + Arguments: + q: (batch * q_seqlen, nheads, headdim) + kv: (batch * kv_seqlen, 2, nheads, headdim) + batch_size: int. + seq_len: int. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + dropout_p: float. + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + out: (total, nheads, headdim). + """ + cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device) + cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, + step=kv_seqlen, + dtype=torch.int32, + device=kv.device) + out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p, + sm_scale, causal) + return out + + def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False): + """ + Arguments: + q: (batch * q_seqlen, nheads, headdim) + k: (batch * kv_seqlen, nheads, headdim) + v: (batch * kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + out: (total, nheads, headdim). + """ + cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device) + cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen, + step=kv_seqlen, + dtype=torch.int32, + device=k.device) + return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale, + causal) + + +if HAS_MEM_EFF_ATTN: + + from einops import rearrange + from xformers.ops.fmha import LowerTriangularMask + + class MemoryEfficientAttention(torch.nn.Module): + + def __init__(self, hidden_size: int, num_attention_heads: int, attention_dropout: float = 0.0): + super().__init__() + attention_head_size = hidden_size // num_attention_heads + self.scale = 1 / attention_head_size**0.5 + self.dropout = attention_dropout + + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor): + context = memory_efficient_attention(query, key, value, attention_mask, self.dropout, self.scale) + context = rearrange(context, 'b s h d -> b s (h d)') + return context diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/kernel/cuda_native/layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..f1b5efa4ec8ce24371129e76e7309e7be7641f24 --- /dev/null +++ b/colossalai/kernel/cuda_native/layer_norm.py @@ -0,0 +1,76 @@ +"""This code is from NVIDIA apex: + https://github.com/NVIDIA/apex + with some changes. """ + +import numbers + +import torch +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn import init +from torch.nn.parameter import Parameter + + +class FusedLayerNormAffineFunction(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, input, weight, bias, normalized_shape, eps): + try: + import colossalai._C.layer_norm + except ImportError: + raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions') + + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = colossalai._C.layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, + ctx.eps) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + try: + import colossalai._C.layer_norm + except ImportError: + raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions') + + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias \ + = colossalai._C.layer_norm.backward_affine( + grad_output.contiguous(), mean, invvar, + input_, ctx.normalized_shape, + weight_, bias_, ctx.eps) + + return grad_input, grad_weight, grad_bias, None, None + + +class MixedFusedLayerNorm(torch.nn.Module): + + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None): + super(MixedFusedLayerNorm, self).__init__() + + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.weight = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype)) + self.bias = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype)) + self.reset_parameters() + + def reset_parameters(self): + + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, input): + + return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps) + + def __repr__(self): + return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})' diff --git a/colossalai/kernel/cuda_native/multihead_attention.py b/colossalai/kernel/cuda_native/multihead_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..84cae529a2b502b5a8b0b54b2eb3691e22602b28 --- /dev/null +++ b/colossalai/kernel/cuda_native/multihead_attention.py @@ -0,0 +1,259 @@ +import math +from dataclasses import dataclass + +import torch +from torch import nn +from torch.autograd import Function + + +def check_config(config): + if config.hidden_size % config.nhead != 0: + raise Exception("hidden_size % nhead != 0") + + factor = 8 if config.fp16 else 4 + upbound = factor * 1024 * 4 + if config.hidden_size > upbound: + # as required by ln backward kernel currently + raise Exception(f"hidden_size > {upbound}") + + head_dim = config.hidden_size // config.nhead + if head_dim % factor != 0: + # as required by reshape kernel + raise Exception(f"head_dim({head_dim}) % {factor} != 0") + + +def calc_offset(sizes): + offsets = [0] + tmp = 0 + for x in sizes: + tmp += x + offsets.append(tmp) + return offsets + + +colossal_multihead_attention = None + + +@dataclass +class Config: + max_batch_tokens: int # max batch token numbers + max_seq_len: int # max sequence length + hidden_size: int # size of transformer hidden layers + nhead: int # number of heads in attention + attn_prob_dropout_ratio: float # attention score dropout ratio + hidden_dropout_ratio: float # dropout ration before residual + norm_first: bool # norm_first + fp16: bool # fp16 presion + + +class MultiHeadAttention1DFunc(Function): + + @staticmethod + def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, + norm_bias, config): + cuda_module = colossal_multihead_attention + forward_func = (cuda_module.multihead_attention_fw_fp16 + if config.fp16 else cuda_module.multihead_attention_fw_fp32) + if config.fp16: + input = input.to(torch.half) + input_mask = input_mask.to(torch.half) + + (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, + out_proj_bias, norm_weight, norm_bias, config.training, config.norm_first) + + if config.is_grad_enabled and config.training: + ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, + out_proj_bias, norm_weight, norm_bias) + ctx.config = config + return output + + @staticmethod + def backward(ctx, grad_output): + assert ctx.config.training + + cuda_module = colossal_multihead_attention + backward_func = (cuda_module.multihead_attention_bw_fp16 + if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32) + + output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, \ + out_proj_bias, norm_weight, norm_bias = ctx.saved_tensors + + grad_input = None + grad_in_proj_weight = None + grad_in_proj_bias = None + grad_out_proj_weight = None + grad_out_proj_bias = None + grad_norm_weight = None + grad_norm_bias = None + + if ctx.config.fp16: + grad_output = grad_output.to(torch.half) + output = output.to(torch.half) + input = input.to(torch.half) + input_mask = input_mask.to(torch.half) + grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, \ + grad_out_proj_bias, grad_norm_weight, grad_norm_bias = backward_func( + ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight, + in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias) + + return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, + grad_norm_weight, grad_norm_bias, None) + + +class MultiHeadAttention(nn.Module): + """Initialize the MultiHeadAttention. + + Static variable: + + layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated, + e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23. + + Arguments: + hidden_size: Total dimension of hidden_size. + nhead: Number of parallel attention heads. + batch_size: Batch Size for one foward + max_seq_len: Max length of input sequence + dropout: Dropout probability + norm_first: perform LayerNorms before attention + """ + + layer_id = 0 + + def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None): + super(MultiHeadAttention, self).__init__() + + self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, + fp16) + check_config(self.config) + self.pg = pg + self.pg_size = 1 + if self.pg: + self.pg_size = pg.size() + self.config.layer_id = MultiHeadAttention.layer_id + MultiHeadAttention.layer_id = MultiHeadAttention.layer_id + 1 + + # Load cuda modules if needed + global colossal_multihead_attention + if colossal_multihead_attention is None: + try: + import colossalai._C.multihead_attention + colossal_multihead_attention = colossalai._C.multihead_attention + except ImportError: + raise RuntimeError('MultiHeadAttention requires cuda extensions') + + # create the layer in cuda kernels. + cuda_module = colossal_multihead_attention + create_layer_func = (cuda_module.create_multihead_attention_fp16 + if self.config.fp16 else cuda_module.create_multihead_attention_fp32) + + create_layer_func( + self.config.layer_id, + self.config.max_batch_tokens, + self.config.max_seq_len, + self.config.hidden_size, + self.config.nhead, + self.config.attn_prob_dropout_ratio, + self.config.hidden_dropout_ratio, + self.config.norm_first, + self.pg, + ) + + hs = self.config.hidden_size + + self.precision = torch.float32 + if self.config.fp16: + self.precision = torch.half + + self.hs_per_rank = int(hs / self.pg_size) + + self.in_proj_weight = nn.Parameter(torch.Tensor(3, self.hs_per_rank, hs)) + self.in_proj_bias = nn.Parameter(torch.Tensor(3, self.hs_per_rank)) + self.out_proj_weight = nn.Parameter(torch.Tensor(hs, self.hs_per_rank)) + self.out_proj_bias = nn.Parameter(torch.Tensor(hs)) + self.norm_weight = nn.Parameter(torch.Tensor(hs)) + self.norm_bias = nn.Parameter(torch.Tensor(hs)) + + self.reset_parameters() + torch.cuda.empty_cache() + + def calc_bound(self, w): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w) + bound = 1.0 / math.sqrt(fan_in) + return bound + + def reset_parameters(self): + hs = self.config.hidden_size + + nn.init.zeros_(self.out_proj_bias) + + nn.init.ones_(self.norm_weight) + nn.init.zeros_(self.norm_bias) + + if self.pg_size > 1: + rank_in_pg = torch.distributed.get_rank(self.pg) + attn_qkvw_global = torch.empty(hs * 3, hs) + attn_qkvb_global = torch.empty(hs * 3) + nn.init.xavier_uniform_(attn_qkvw_global, 1.0 / math.sqrt(2.0)) + bound = self.calc_bound(attn_qkvw_global) + nn.init.uniform_(attn_qkvb_global, -bound, bound) + + attn_qkvw_global = attn_qkvw_global.cuda() + attn_qkvb_global = attn_qkvb_global.cuda() + torch.distributed.broadcast(attn_qkvw_global, src=0, group=self.pg) + torch.distributed.broadcast(attn_qkvb_global, src=0, group=self.pg) + attn_qkvw_global = attn_qkvw_global.cpu() + attn_qkvb_global = attn_qkvb_global.cpu() + + with torch.no_grad(): + self.in_proj_weight.copy_( + attn_qkvw_global.view(3, hs, hs)[:, + int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / + self.pg_size), :]) + self.in_proj_bias.copy_( + attn_qkvb_global.view(3, hs)[:, + int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / + self.pg_size)]) + + attn_ow_global = torch.empty(hs, hs) + nn.init.xavier_uniform_(attn_ow_global, 1.0) + attn_ow_global = attn_ow_global.cuda() + torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg) + attn_ow_global = attn_ow_global.cpu() + with torch.no_grad(): + self.out_proj_weight.copy_(attn_ow_global[:, + int(hs * rank_in_pg / + self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)]) + + else: + attn_qkvw = self.in_proj_weight.view(-1, hs) + nn.init.xavier_uniform_(attn_qkvw, 1.0 / math.sqrt(2.0)) + bound = self.calc_bound(attn_qkvw) + nn.init.uniform_(self.in_proj_bias, -bound, bound) + + nn.init.xavier_uniform_(self.out_proj_weight, 1.0) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars) + return destination + + def forward(self, hidden_states, encoder_padding_mask): + self.config.training = self.training + self.config.is_grad_enabled = torch.is_grad_enabled() + hidden_states = hidden_states.contiguous() + encoder_padding_mask = ((encoder_padding_mask * -1e8).type_as(hidden_states).contiguous()) + + bs, sl, dim = hidden_states.size() + if bs * sl > self.config.max_batch_tokens: + raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.") + if sl > self.config.max_seq_len: + raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.") + if len(encoder_padding_mask.size()) == 1: + assert bs == 1 and sl == encoder_padding_mask.size(0) + else: + assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1) + + output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, self.in_proj_weight, + self.in_proj_bias, self.out_proj_weight, self.out_proj_bias, + self.norm_weight, self.norm_bias, self.config) + + return output.to(self.precision) diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..e02067d05f4f66de44e64b32ca7b8742e82eb512 --- /dev/null +++ b/colossalai/kernel/cuda_native/scaled_softmax.py @@ -0,0 +1,193 @@ +"""This code from NVIDIA Megatron + with some changes. """ + +import enum + +import torch +import torch.nn as nn + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + + +class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + try: + import colossalai._C.scaled_upper_triang_masked_softmax + except ImportError: + raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions') + + scale_t = torch.tensor([scale]) + softmax_results = colossalai._C.scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) + + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + try: + import colossalai._C.scaled_upper_triang_masked_softmax + except ImportError: + raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions') + + softmax_results, scale_t = ctx.saved_tensors + input_grads = colossalai._C.scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, + scale_t[0]) + + return input_grads, None + + +class ScaledMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + + 1. Scale the tensor. + 2. Apply the mask. + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, mask, scale): + try: + import colossalai._C.scaled_masked_softmax + except ImportError: + raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions') + + scale_t = torch.tensor([scale]) + + softmax_results = colossalai._C.scaled_masked_softmax.forward(inputs, mask, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + try: + import colossalai._C.scaled_masked_softmax + except ImportError: + raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions') + + softmax_results, scale_t = ctx.saved_tensors + + input_grads = colossalai._C.scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + return input_grads, None, None + + +class FusedScaleMaskSoftmax(nn.Module): + """ + Fused operation: scaling + mask + softmax + + Arguments: + input_in_fp16: Flag to indicate if input in fp16 data format. + input_in_bf16: Flag to indicate if input in bf16 data format. + attn_mask_type: Attention mask type (pad or causal) + scaled_masked_softmax_fusion: Flag to indicate user want to use softmax fusion + mask_func: Mask function to be applied. + softmax_in_fp32: If True, softmax in performed at fp32 precision. + scale: Scaling factor used in input tensor scaling. + """ + + def __init__( + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, + ): + super(FusedScaleMaskSoftmax, self).__init__() + self.input_in_fp16 = input_in_fp16 + self.input_in_bf16 = input_in_bf16 + assert not (self.input_in_fp16 + and self.input_in_bf16), "both fp16 and bf16 flags cannot be active at the same time." + self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + self.attn_mask_type = attn_mask_type + self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion + self.mask_func = mask_func + self.softmax_in_fp32 = softmax_in_fp32 + self.scale = scale + + assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled" + + def forward(self, input, mask): + # [b, np, sq, sk] + assert input.dim() == 4 + + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) + else: + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if (self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and mask is not None # mask tensor must not be None + and 16 < sk <= 2048 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 2048: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.attn_mask_type == AttnMaskType.causal: + if attn_batches % batch_per_block == 0: + return True + else: + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + b, np, sq, sk = input.size() + scale = self.scale if self.scale is not None else 1.0 + + if self.attn_mask_type == AttnMaskType.causal: + assert sq == sk, "causal mask is only for self attention" + + # input is 3D tensor (attn_batches, sq, sk) + input = input.view(-1, sq, sk) + probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) + return probs.view(b, np, sq, sk) + else: + # input is 4D tensor (b, np, sq, sk) + return ScaledMaskedSoftmax.apply(input, mask, scale) + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + + return probs + + @staticmethod + def get_batch_per_block(sq, sk, b, np): + try: + import colossalai._C.scaled_masked_softmax + except ImportError: + raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions') + + return colossalai._C.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np) diff --git a/colossalai/kernel/jit/__init__.py b/colossalai/kernel/jit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57b8fb7b2e996ea0f0336dad1e42ea379d608b15 --- /dev/null +++ b/colossalai/kernel/jit/__init__.py @@ -0,0 +1,8 @@ +from .option import set_jit_fusion_options +from .bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference +from .bias_gelu import bias_gelu_impl + +__all__ = [ + "bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl", + "set_jit_fusion_options" +] diff --git a/colossalai/kernel/jit/bias_dropout_add.py b/colossalai/kernel/jit/bias_dropout_add.py new file mode 100644 index 0000000000000000000000000000000000000000..3687dde79a08b7f8f192d6516694938828aae659 --- /dev/null +++ b/colossalai/kernel/jit/bias_dropout_add.py @@ -0,0 +1,24 @@ +import torch + + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor + out = torch.nn.functional.dropout(x + bias, p=prob, training=training) + out = residual + out + return out + + +@torch.jit.script +def bias_dropout_add_fused_train(x: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, True) + + +@torch.jit.script +def bias_dropout_add_fused_inference(x: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, False) diff --git a/colossalai/kernel/jit/bias_gelu.py b/colossalai/kernel/jit/bias_gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..e6da70c40b42cb735ef1c39a6c86297ccd51a5f8 --- /dev/null +++ b/colossalai/kernel/jit/bias_gelu.py @@ -0,0 +1,45 @@ +import torch + +###### BIAS GELU FUSION/ NO AUTOGRAD ################ +# 1/sqrt(2*pi)-> 0.3989423 +# 1/sqrt(2) -> 0.70710678 +# sqrt(2/pi) -> 0.79788456 +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) + + +@torch.jit.script +def bias_gelu(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def bias_gelu_back(g, bias, y): + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff * g + + +class GeLUFunction(torch.autograd.Function): + + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(bias, input) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = bias_gelu_back(grad_output, bias, input) + return tmp, tmp + + +bias_gelu_impl = GeLUFunction.apply diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py new file mode 100644 index 0000000000000000000000000000000000000000..aa41f57678fc116ac4acef72f52763f5dadabfed --- /dev/null +++ b/colossalai/kernel/jit/option.py @@ -0,0 +1,79 @@ +import torch + +from colossalai.nn.layer.colossalai_layer import Embedding, Linear +from colossalai.utils import get_current_device + +from .bias_dropout_add import bias_dropout_add_fused_train +from .bias_gelu import bias_gelu_impl + +JIT_OPTIONS_SET = False + + +def set_jit_fusion_options(): + """Set PyTorch JIT layer fusion options. + """ + # LSG: the latest pytorch and CUDA versions may not support + # the following jit settings + global JIT_OPTIONS_SET + if JIT_OPTIONS_SET == False: + # flags required to enable jit fusion kernels + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): + # nvfuser + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(True) + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(True) + torch._C._debug_set_autodiff_subgraph_inlining(False) + else: + # legacy pytorch fuser + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + + JIT_OPTIONS_SET = True + + +def warmup_jit_fusion(batch_size: int, + hidden_size: int, + seq_length: int = 512, + vocab_size: int = 32768, + dtype: torch.dtype = torch.float32): + """ Compilie JIT functions before the main training steps """ + + embed = Embedding(vocab_size, hidden_size).to(get_current_device()) + linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device()) + linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device()) + + x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device()) + x = embed(x) + y, y_bias = linear_1(x) + z, z_bias = linear_2(y) + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for bias_grad, input_grad in zip([True, True], [False, True]): + for _ in range(10): + bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device()) + input_ = torch.rand_like(y, dtype=dtype, device=get_current_device()) + bias.requires_grad, input_.requires_grad = bias_grad, input_grad + bias_gelu_impl(input_, bias) + + # Warmup fused bias+dropout+add + dropout_rate = 0.1 + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]): + for _ in range(10): + input_ = torch.rand_like(z, dtype=dtype, device=get_current_device()) + residual = torch.rand_like(x, dtype=dtype, device=get_current_device()) + bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device()) + input_.requires_grad = input_grad + bias.requires_grad = bias_grad + residual.requires_grad = residual_grad + bias_dropout_add_fused_train(input_, bias, residual, dropout_rate) + + torch.cuda.empty_cache() diff --git a/colossalai/logging/__init__.py b/colossalai/logging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97fe4f89ded370c6a71ce277bb33e252912b6f17 --- /dev/null +++ b/colossalai/logging/__init__.py @@ -0,0 +1,38 @@ +import logging +from typing import List, Optional + +from .logger import DistributedLogger + +__all__ = ['get_dist_logger', 'DistributedLogger', 'disable_existing_loggers'] + + +def get_dist_logger(name: str = 'colossalai') -> DistributedLogger: + """Get logger instance based on name. The DistributedLogger will create singleton instances, + which means that only one logger instance is created per name. + + Args: + name (str): name of the logger, name must be unique + + Returns: + :class:`colossalai.logging.DistributedLogger`: A distributed logger singleton instance. + """ + return DistributedLogger.get_instance(name=name) + + +def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']) -> None: + """Set the level of existing loggers to `WARNING`. By default, it will "disable" all existing loggers except the logger named "colossalai". + + Args: + include (Optional[List[str]], optional): Loggers whose name in this list will be disabled. + If set to `None`, `exclude` argument will be used. Defaults to None. + exclude (List[str], optional): Loggers whose name not in this list will be disabled. + This argument will be used only when `include` is None. Defaults to ['colossalai']. + """ + if include is None: + filter_func = lambda name: name not in exclude + else: + filter_func = lambda name: name in include + + for log_name in logging.Logger.manager.loggerDict.keys(): + if filter_func(log_name): + logging.getLogger(log_name).setLevel(logging.WARNING) diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..8d50ee41819502c59e00aefb52f8d5b00ff074d5 --- /dev/null +++ b/colossalai/logging/logger.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import inspect +import logging +from pathlib import Path +from typing import List, Union + +import colossalai +from colossalai.context.parallel_mode import ParallelMode + + +class DistributedLogger: + """This is a distributed event logger class essentially based on :class:`logging`. + + Args: + name (str): The name of the logger. + + Note: + The parallel_mode used in ``info``, ``warning``, ``debug`` and ``error`` + should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + __instances = dict() + + @staticmethod + def get_instance(name: str): + """Get the unique single logger instance based on name. + + Args: + name (str): The name of the logger. + + Returns: + DistributedLogger: A DistributedLogger object + """ + if name in DistributedLogger.__instances: + return DistributedLogger.__instances[name] + else: + logger = DistributedLogger(name=name) + return logger + + def __init__(self, name): + if name in DistributedLogger.__instances: + raise Exception( + 'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger') + else: + handler = None + formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s') + try: + from rich.logging import RichHandler + handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True) + handler.setFormatter(formatter) + except ImportError: + handler = logging.StreamHandler() + handler.setFormatter(formatter) + + self._name = name + self._logger = logging.getLogger(name) + self._logger.setLevel(logging.INFO) + if handler is not None: + self._logger.addHandler(handler) + self._logger.propagate = False + + DistributedLogger.__instances[name] = self + + @staticmethod + def __get_call_info(): + stack = inspect.stack() + + # stack[1] gives previous function ('info' in our case) + # stack[2] gives before previous function and so on + + fn = stack[2][1] + ln = stack[2][2] + func = stack[2][3] + + return fn, ln, func + + @staticmethod + def _check_valid_logging_level(level: str): + assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level' + + def set_level(self, level: str) -> None: + """Set the logging level + + Args: + level (str): Can only be INFO, DEBUG, WARNING and ERROR. + """ + self._check_valid_logging_level(level) + self._logger.setLevel(getattr(logging, level)) + + def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None) -> None: + """Save the logs to file + + Args: + path (A string or pathlib.Path object): The file to save the log. + mode (str): The mode to write log into the file. + level (str): Can only be INFO, DEBUG, WARNING and ERROR. + suffix (str): The suffix string of log's name. + """ + assert isinstance(path, (str, Path)), \ + f'expected argument path to be type str or Path, but got {type(path)}' + self._check_valid_logging_level(level) + + if isinstance(path, str): + path = Path(path) + + # create log directory + path.mkdir(parents=True, exist_ok=True) + + # set the default file name if path is a directory + if not colossalai.core.global_context.is_initialized(ParallelMode.GLOBAL): + rank = 0 + else: + rank = colossalai.core.global_context.get_global_rank() + + if suffix is not None: + log_file_name = f'rank_{rank}_{suffix}.log' + else: + log_file_name = f'rank_{rank}.log' + path = path.joinpath(log_file_name) + + # add file handler + file_handler = logging.FileHandler(path, mode) + file_handler.setLevel(getattr(logging, level)) + formatter = logging.Formatter(_FORMAT) + file_handler.setFormatter(formatter) + self._logger.addHandler(file_handler) + + def _log(self, + level, + message: str, + parallel_mode: ParallelMode = ParallelMode.GLOBAL, + ranks: List[int] = None) -> None: + if ranks is None: + getattr(self._logger, level)(message) + else: + local_rank = colossalai.core.global_context.get_local_rank(parallel_mode) + if local_rank in ranks: + getattr(self._logger, level)(message) + + def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: + """Log an info message. + + Args: + message (str): The message to be logged. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): + The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. + ranks (List[int]): List of parallel ranks. + """ + message_prefix = "{}:{} {}".format(*self.__get_call_info()) + self._log('info', message_prefix, parallel_mode, ranks) + self._log('info', message, parallel_mode, ranks) + + def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: + """Log a warning message. + + Args: + message (str): The message to be logged. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): + The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. + ranks (List[int]): List of parallel ranks. + """ + message_prefix = "{}:{} {}".format(*self.__get_call_info()) + self._log('warning', message_prefix, parallel_mode, ranks) + self._log('warning', message, parallel_mode, ranks) + + def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: + """Log a debug message. + + Args: + message (str): The message to be logged. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): + The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. + ranks (List[int]): List of parallel ranks. + """ + message_prefix = "{}:{} {}".format(*self.__get_call_info()) + self._log('debug', message_prefix, parallel_mode, ranks) + self._log('debug', message, parallel_mode, ranks) + + def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: + """Log an error message. + + Args: + message (str): The message to be logged. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): + The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. + ranks (List[int]): List of parallel ranks. + """ + message_prefix = "{}:{} {}".format(*self.__get_call_info()) + self._log('error', message_prefix, parallel_mode, ranks) + self._log('error', message, parallel_mode, ranks) diff --git a/colossalai/nn/__init__.py b/colossalai/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..910ad203180c8dd533ebc7732d26d94a20b72929 --- /dev/null +++ b/colossalai/nn/__init__.py @@ -0,0 +1,6 @@ +from ._ops import * +from .layer import * +from .loss import * +from .lr_scheduler import * +from .metric import * +from .optimizer import * diff --git a/colossalai/nn/_ops/__init__.py b/colossalai/nn/_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4991ad9a2217f904287d438ca37c8c4717a40a67 --- /dev/null +++ b/colossalai/nn/_ops/__init__.py @@ -0,0 +1,9 @@ +from .addmm import colo_addmm +from .batch_norm import colo_batch_norm +from .element_wise import * +from .embedding import colo_embedding +from .embedding_bag import colo_embedding_bag +from .layernorm import colo_layernorm +from .linear import colo_linear +from .loss import colo_cross_entropy +from .view import colo_view diff --git a/colossalai/nn/_ops/_utils.py b/colossalai/nn/_ops/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..56bb5f465184516e6cb427f1fb98ee76f517170c --- /dev/null +++ b/colossalai/nn/_ops/_utils.py @@ -0,0 +1,284 @@ +import torch +from typing import Union, Optional, List +from colossalai.tensor import ColoTensor +import torch +import torch.distributed as dist +from colossalai.global_variables import tensor_parallel_env as env + +from colossalai.nn.layer.utils import divide +from colossalai.tensor import ProcessGroup, ColoTensorSpec + +GeneralTensor = Union[ColoTensor, torch.Tensor] +Number = Union[int, float] + + +def convert_to_colo_tensor(tensor: Optional[GeneralTensor], pg: ProcessGroup) -> Optional[ColoTensor]: + if tensor is not None and not isinstance(tensor, ColoTensor): + tensor = ColoTensor.from_torch_tensor(tensor, ColoTensorSpec(pg)) + return tensor + + +def set_parallel_input(input_parallel: bool): + env.parallel_input_1d = input_parallel + + +def get_parallel_input(): + return env.parallel_input_1d + + +def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank): + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + +def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): + per_partition_vocab_size = divide(global_vocab_size, world_size) + return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank) + + +def _reduce(input_, pg: ProcessGroup): + # skip if only one rank involved + if pg.tp_world_size() == 1: + return input_ + assert input_.device.type == 'cuda' + group = pg.tp_process_group() + dist.all_reduce(input_, group=group) + + return input_ + + +def _split(input_, pg: ProcessGroup, dim=-1): + # skip if only one rank involved + world_size = pg.tp_world_size() + if world_size == 1: + return input_ + + # Split along last dimension. + dim_size = input_.size(dim) + assert dim_size % world_size == 0, \ + f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ + f'cannot split tensor evenly' + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + rank = pg.tp_local_rank() + output = tensor_list[rank].contiguous() + + return output + + +def _gather(input_, pg: ProcessGroup, dim=-1): + # skip if only one rank involved + world_size = pg.tp_world_size() + if world_size == 1: + return input_ + + # all gather + rank = pg.tp_local_rank() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + assert input_.device.type == 'cuda' + group = pg.tp_process_group() + torch.distributed.all_gather(tensor_list, input_, group=group) + + # concat + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +class _ReduceGrad(torch.autograd.Function): + """ + Pass the input to the model parallel region. + + Args: + input_: input matrix. + process_group: parallel mode. + """ + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + def forward(ctx, input_, process_group): + ctx.mode = process_group + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output, ctx.mode), None + + +class _ReduceInput(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + process_group: parallel mode. + """ + + @staticmethod + def symbolic(graph, input_): + return _reduce(input_) + + @staticmethod + def forward(ctx, input_, process_group): + return _reduce(input_, process_group) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_: input matrix. + process_group: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _split(input_) + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.mode = process_group + ctx.dim = dim + return _split(input_, process_group, dim) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output, ctx.mode, ctx.dim), None, None + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + process_group: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _gather(input_) + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.mode = process_group + ctx.dim = dim + return _gather(input_, process_group, dim) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.mode, ctx.dim), None, None + + +def reduce_grad(input_, process_group): + return _ReduceGrad.apply(input_, process_group) + + +def reduce_input(input_, process_group): + return _ReduceInput.apply(input_, process_group) + + +def split_forward_gather_backward(input_, process_group, dim): + return _SplitForwardGatherBackward.apply(input_, process_group, dim) + + +def gather_forward_split_backward(input_, process_group, dim): + return _GatherForwardSplitBackward.apply(input_, process_group, dim) + + +def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim: int) -> torch.Tensor: + world_size = pg.tp_world_size() + if world_size == 1: + return x + + # TODO: enabling mpi backend to support CPU all_to_all + assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend" + + shapes = list(x.size()) + shapes[scatter_dim] = shapes[scatter_dim] // world_size + + scatter_list = [each.contiguous() for each in torch.tensor_split(x, world_size, scatter_dim)] + gather_list = [torch.empty(*shapes, dtype=x.dtype, device=x.device) for _ in range(world_size)] + torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) + + return torch.cat(gather_list, dim=gather_dim).contiguous() + + +class _DualAllToAll(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, pg, scatter_dim, gather_dim): + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + ctx.pg = pg + return _all_to_all(x, pg, scatter_dim, gather_dim) + + @staticmethod + def backward(ctx, grad): + return _all_to_all(grad, ctx.pg, ctx.gather_dim, ctx.scatter_dim), None, None, None + + +def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int): + return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim) + + +### table wise embedding shard + + +def _all_to_all_for_tablewise(x: torch.Tensor, + pg: ProcessGroup, + scatter_strides: List[int], + gather_strides: List[int], + forward=True) -> torch.Tensor: + world_size = pg.tp_world_size() + rank = pg.tp_local_rank() + if world_size == 1: + return x + assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend" + if forward: + scatter_list = list(x.split(scatter_strides, 0)) + gather_list = [ + torch.empty(scatter_strides[rank], gather_strides[i], dtype=x.dtype, device=x.device) + for i in range(world_size) + ] + torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) + return torch.cat(gather_list, 1).contiguous() + else: + # split on dim 1, lose contiguity + scatter_list = [each.contiguous() for each in x.split(scatter_strides, 1)] + gather_list = [ + torch.empty(gather_strides[i], scatter_strides[rank], dtype=x.dtype, device=x.device) + for i in range(world_size) + ] + torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) + return torch.cat(gather_list, 0).contiguous() + + +class _DualAllToAllForTablewise(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, pg, scatter_strides, gather_strides): + ctx.pg = pg + ctx.scatter_strides = scatter_strides + ctx.gather_strides = gather_strides + return _all_to_all_for_tablewise(x, pg, scatter_strides, gather_strides, forward=True) + + @staticmethod + def backward(ctx, grad): + return _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides, + forward=False), None, None, None + + +def dual_all_to_all_tablewise(x, pg, scatter_strides, gather_strides): + return _DualAllToAllForTablewise.apply(x, pg, scatter_strides, gather_strides) diff --git a/colossalai/nn/_ops/addmm.py b/colossalai/nn/_ops/addmm.py new file mode 100644 index 0000000000000000000000000000000000000000..ce7e8bef63e7e6751c9ab2f169e5e8b371d3ee59 --- /dev/null +++ b/colossalai/nn/_ops/addmm.py @@ -0,0 +1,86 @@ +import torch +from colossalai.tensor.op_wrapper import colo_op_impl +from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor +from colossalai.tensor import distspec, ColoTensorSpec, ShardSpec, ReplicaSpec +from ._utils import GeneralTensor, Number, convert_to_colo_tensor +from ._utils import reduce_input, reduce_grad + + +def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, + alpha: Number) -> ColoTensor: + # mat1:S[1] x mat2:S[0] = Output:P + # beta * input + alpha * All-Reduce(Output) = res + + mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]), mat2.get_process_group()) + + # Output:P + partial_output = torch.mm(mat1, mat2) + # Reduce(Output) + output = reduce_input(partial_output, mat2.get_process_group()) + # input + assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op' + output = beta * input_tensor + alpha * output + output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(input_tensor.get_process_group())) + return output + + +def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, + alpha: Number) -> ColoTensor: + # mat1:B x mat2:S[1] + input:S[1] = Output:S[1] + compute_spec = mat2.compute_spec + mat1 = mat1.redistribute(ReplicaSpec()) + mat1 = reduce_grad(mat1, mat1.get_process_group()) + + output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha) + output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]), + ComputeSpec(ComputePattern.TP1D)) + output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) + + if compute_spec.output_replicate: + return output.to_replicate() + else: + return output + + +def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, + alpha: Number) -> ColoTensor: + assert mode in ('row', 'col') + funcs = {'row': colo_addmm_1Drow, 'col': colo_addmm_1Dcol} + return funcs[mode](input_tensor, mat1, mat2, beta, alpha) + + +@colo_op_impl(torch.addmm) +def colo_addmm(input_tensor: GeneralTensor, + mat1: ColoTensor, + mat2: ColoTensor, + beta: Number = 1, + alpha: Number = 1, + *args) -> ColoTensor: + """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. + This method computes a linear. + """ + # At least one of the tensor should be ColoTensor + assert isinstance(mat2, ColoTensor) + input_tensor = convert_to_colo_tensor(input_tensor, mat2.get_process_group()) + mat1 = convert_to_colo_tensor(mat1, mat2.get_process_group()) + + # Add communication logic before and after linear call. + ret_tensor = None + if not mat2.has_compute_spec(): # No Model Parallel Applied + assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op' + assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op' + ret_tensor = ColoTensor.from_torch_tensor( + tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha), + spec=ColoTensorSpec(mat2.get_process_group())) + elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied + if mat2.is_shard_1drow() and input_tensor.is_replicate(): + mode = 'row' + elif mat2.is_shard_1dcol() and (input_tensor.is_shard_1dcol() or input_tensor.is_shard_1drow()): + mode = 'col' + else: + raise NotImplementedError + ret_tensor = colo_addmm_1d(mode, input_tensor, mat1, mat2, beta, alpha) + else: + raise NotImplementedError + + return ret_tensor diff --git a/colossalai/nn/_ops/batch_norm.py b/colossalai/nn/_ops/batch_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..54ecc88f420a8d8a2c81aca6ac765e57f5a56cac --- /dev/null +++ b/colossalai/nn/_ops/batch_norm.py @@ -0,0 +1,33 @@ +from typing import Optional + +import torch.nn.functional as F + +from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec +from colossalai.tensor.op_wrapper import colo_op_impl + +from ._utils import GeneralTensor, convert_to_colo_tensor + + +@colo_op_impl(F.batch_norm) +def colo_batch_norm( + input: GeneralTensor, + running_mean: Optional[GeneralTensor], + running_var: Optional[GeneralTensor], + weight: Optional[GeneralTensor] = None, + bias: Optional[GeneralTensor] = None, + training: bool = False, + momentum: float = 0.1, + eps: float = 1e-5, +): + assert isinstance(weight, ColoTensor) + running_mean = running_mean.detach() + running_var = running_var.detach() + + input = convert_to_colo_tensor(input, weight.get_process_group()) + bias = convert_to_colo_tensor(bias, weight.get_process_group()) + input = input.redistribute(ReplicaSpec()) + bias = bias.redistribute(ReplicaSpec()) + + output = F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps) + output = ColoTensor.from_torch_tensor(tensor=output, spec=ColoTensorSpec(pg=weight.get_process_group())) + return output diff --git a/colossalai/nn/_ops/element_wise.py b/colossalai/nn/_ops/element_wise.py new file mode 100644 index 0000000000000000000000000000000000000000..2de51e24a6dd1bd45271a0b8d51372ee5209415d --- /dev/null +++ b/colossalai/nn/_ops/element_wise.py @@ -0,0 +1,250 @@ +import torch +import torch.nn.functional as F +from torch import Tensor + +from colossalai.tensor import ColoTensor, ColoTensorSpec +from colossalai.tensor.op_wrapper import colo_op_impl + +from ._utils import GeneralTensor, convert_to_colo_tensor + + +def register_elementwise_op(op): + + @colo_op_impl(op) + def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs): + """ + Handles ``__torch_function__`` dispatch for the elementwise op such + as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``. + This method computes on either a normal tensor or a sharded tensor. + """ + if 'inplace' in kwargs: + # TODO(jiaruifang) inplace will cause bugs + input_tensor = input_tensor.clone() + return op(input_tensor, *args, **kwargs) + else: + output = op(input_tensor, *args, **kwargs) + # return output + if isinstance(input_tensor, ColoTensor): + if isinstance(output, str): + return output + if not isinstance(output, torch.Tensor): + raise NotImplementedError + return ColoTensor.from_torch_tensor(output, + spec=ColoTensorSpec(input_tensor.get_process_group(), + dist_attr=input_tensor.dist_spec)) + + +# @colo_op_impl(torch.relu_) +# def elementwise_op(input_tensor): +# torch.relu_(input_tensor.data) +# return input_tensor + +# @colo_op_impl(Tensor.add_) +# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs): +# input_tensor = input_tensor.data.add_(*args, **kwargs) +# return input_tensor + +# Tensor op +register_elementwise_op(Tensor.abs) +register_elementwise_op(Tensor.absolute) +register_elementwise_op(Tensor.acos) +register_elementwise_op(Tensor.arccos) +register_elementwise_op(Tensor.angle) +register_elementwise_op(Tensor.asin) +register_elementwise_op(Tensor.arcsin) +register_elementwise_op(Tensor.atan) +register_elementwise_op(Tensor.arctan) +register_elementwise_op(Tensor.all) +register_elementwise_op(Tensor.any) +register_elementwise_op(Tensor.bernoulli) +register_elementwise_op(Tensor.bfloat16) +register_elementwise_op(Tensor.bitwise_not) +register_elementwise_op(Tensor.bool) +register_elementwise_op(Tensor.byte) +register_elementwise_op(Tensor.ceil) +register_elementwise_op(Tensor.char) +register_elementwise_op(Tensor.clamp) +register_elementwise_op(Tensor.clamp_max) +register_elementwise_op(Tensor.clamp_min) +register_elementwise_op(Tensor.clip) +register_elementwise_op(Tensor.clone) +register_elementwise_op(Tensor.contiguous) +register_elementwise_op(Tensor.copysign) +register_elementwise_op(Tensor.cos) +register_elementwise_op(Tensor.cosh) +register_elementwise_op(Tensor.acosh) +register_elementwise_op(Tensor.arccosh) +register_elementwise_op(Tensor.cpu) +register_elementwise_op(Tensor.cuda) +register_elementwise_op(Tensor.deg2rad) +register_elementwise_op(Tensor.detach) +register_elementwise_op(Tensor.digamma) +register_elementwise_op(Tensor.double) +register_elementwise_op(Tensor.erf) +register_elementwise_op(Tensor.erfc) +register_elementwise_op(Tensor.erfinv) +register_elementwise_op(Tensor.exp) +register_elementwise_op(Tensor.expm1) +register_elementwise_op(Tensor.fix) +register_elementwise_op(Tensor.trunc) +register_elementwise_op(Tensor.float) +register_elementwise_op(Tensor.float_power) +register_elementwise_op(Tensor.floor) +register_elementwise_op(Tensor.frac) +register_elementwise_op(Tensor.half) +register_elementwise_op(Tensor.hardshrink) +register_elementwise_op(Tensor.heaviside) +register_elementwise_op(Tensor.i0) +register_elementwise_op(Tensor.int) +register_elementwise_op(Tensor.isfinite) +register_elementwise_op(Tensor.isinf) +register_elementwise_op(Tensor.isposinf) +register_elementwise_op(Tensor.isneginf) +register_elementwise_op(Tensor.isnan) +register_elementwise_op(Tensor.lgamma) +register_elementwise_op(Tensor.log) +register_elementwise_op(Tensor.log10) +register_elementwise_op(Tensor.log1p) +register_elementwise_op(Tensor.log2) +register_elementwise_op(Tensor.logical_not) +register_elementwise_op(Tensor.logit) +register_elementwise_op(Tensor.long) +register_elementwise_op(Tensor.nan_to_num) +register_elementwise_op(Tensor.neg) +register_elementwise_op(Tensor.negative) +register_elementwise_op(Tensor.positive) +register_elementwise_op(Tensor.pow) +register_elementwise_op(Tensor.rad2deg) +register_elementwise_op(Tensor.reciprocal) +register_elementwise_op(Tensor.round) +register_elementwise_op(Tensor.rsqrt) +register_elementwise_op(Tensor.short) +register_elementwise_op(Tensor.sigmoid) +register_elementwise_op(Tensor.sign) +register_elementwise_op(Tensor.signbit) +register_elementwise_op(Tensor.sgn) +register_elementwise_op(Tensor.sin) +register_elementwise_op(Tensor.sinc) +register_elementwise_op(Tensor.sinh) +register_elementwise_op(Tensor.asinh) +register_elementwise_op(Tensor.arcsinh) +register_elementwise_op(Tensor.sqrt) +register_elementwise_op(Tensor.square) +register_elementwise_op(Tensor.to) +register_elementwise_op(Tensor.tan) +register_elementwise_op(Tensor.tanh) +register_elementwise_op(Tensor.atanh) +register_elementwise_op(Tensor.arctanh) +register_elementwise_op(Tensor.type) +register_elementwise_op(Tensor.type_as) + +# torch OP +register_elementwise_op(torch.abs) +register_elementwise_op(torch.absolute) +register_elementwise_op(torch.acos) +register_elementwise_op(torch.arccos) +register_elementwise_op(torch.angle) +register_elementwise_op(torch.asin) +register_elementwise_op(torch.arcsin) +register_elementwise_op(torch.atan) +register_elementwise_op(torch.arctan) +register_elementwise_op(torch.all) +register_elementwise_op(torch.any) +register_elementwise_op(torch.bernoulli) +register_elementwise_op(torch.bitwise_not) +register_elementwise_op(torch.ceil) +register_elementwise_op(torch.clamp) +register_elementwise_op(torch.clamp_max) +register_elementwise_op(torch.clamp_min) +register_elementwise_op(torch.clip) +register_elementwise_op(torch.clone) +register_elementwise_op(torch.copysign) +register_elementwise_op(torch.cos) +register_elementwise_op(torch.cosh) +register_elementwise_op(torch.acosh) +register_elementwise_op(torch.arccosh) +register_elementwise_op(torch.deg2rad) +register_elementwise_op(torch.digamma) +register_elementwise_op(torch.erf) +register_elementwise_op(torch.erfc) +register_elementwise_op(torch.erfinv) +register_elementwise_op(torch.exp) +register_elementwise_op(torch.expm1) +register_elementwise_op(torch.fix) +register_elementwise_op(torch.trunc) +register_elementwise_op(torch.float_power) +register_elementwise_op(torch.floor) +register_elementwise_op(torch.frac) +register_elementwise_op(torch.hardshrink) +register_elementwise_op(torch.heaviside) +register_elementwise_op(torch.i0) +register_elementwise_op(torch.isfinite) +register_elementwise_op(torch.isinf) +register_elementwise_op(torch.isposinf) +register_elementwise_op(torch.isneginf) +register_elementwise_op(torch.isnan) +register_elementwise_op(torch.lgamma) +register_elementwise_op(torch.log) +register_elementwise_op(torch.log10) +register_elementwise_op(torch.log1p) +register_elementwise_op(torch.log2) +register_elementwise_op(torch.logical_not) +register_elementwise_op(torch.logit) +register_elementwise_op(torch.nan_to_num) +register_elementwise_op(torch.neg) +register_elementwise_op(torch.negative) +register_elementwise_op(torch.positive) +register_elementwise_op(torch.pow) +register_elementwise_op(torch.rad2deg) +register_elementwise_op(torch.reciprocal) +register_elementwise_op(torch.round) +register_elementwise_op(torch.rsqrt) +register_elementwise_op(torch.sigmoid) +register_elementwise_op(torch.sign) +register_elementwise_op(torch.signbit) +register_elementwise_op(torch.sgn) +register_elementwise_op(torch.sin) +register_elementwise_op(torch.sinc) +register_elementwise_op(torch.sinh) +register_elementwise_op(torch.asinh) +register_elementwise_op(torch.arcsinh) +register_elementwise_op(torch.sqrt) +register_elementwise_op(torch.square) +register_elementwise_op(torch.tan) +register_elementwise_op(torch.tanh) +register_elementwise_op(torch.atanh) +register_elementwise_op(torch.arctanh) +register_elementwise_op(torch.zeros_like) + +# nn.functional OP +register_elementwise_op(F.threshold) +register_elementwise_op(F.relu) +register_elementwise_op(F.hardtanh) +register_elementwise_op(F.hardswish) +register_elementwise_op(F.relu6) +register_elementwise_op(F.elu) +register_elementwise_op(F.selu) +register_elementwise_op(F.celu) +register_elementwise_op(F.leaky_relu) +register_elementwise_op(F.prelu) +register_elementwise_op(F.rrelu) +register_elementwise_op(F.gelu) +register_elementwise_op(F.logsigmoid) +register_elementwise_op(F.hardshrink) +register_elementwise_op(F.tanhshrink) +register_elementwise_op(F.softsign) +register_elementwise_op(F.softplus) +register_elementwise_op(F.softmin) +register_elementwise_op(F.softmax) +register_elementwise_op(F.softshrink) +register_elementwise_op(F.gumbel_softmax) +register_elementwise_op(F.log_softmax) +register_elementwise_op(F.tanh) +register_elementwise_op(F.sigmoid) +register_elementwise_op(F.hardsigmoid) +register_elementwise_op(F.silu) +register_elementwise_op(F.mish) +# TODO(ver217): dropout handles seed +register_elementwise_op(F.dropout) +register_elementwise_op(F.alpha_dropout) +register_elementwise_op(F.feature_alpha_dropout) diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/nn/_ops/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..a045f305b5dc72454043298b2a69f114ad50f1e9 --- /dev/null +++ b/colossalai/nn/_ops/embedding.py @@ -0,0 +1,140 @@ +import torch.nn.functional as F +from typing import Optional +from colossalai.tensor.op_wrapper import colo_op_impl +from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, \ + ReplicaSpec +from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input + + +def colo_embedding_1Dcol(input_tensor: ColoTensor, + weight: ColoTensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False) -> ColoTensor: + # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P) + # Gather splitted lookup table + input_tensor = input_tensor.redistribute(ReplicaSpec()) + + output_parallel = F.embedding(input_tensor, + weight, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]), + ComputeSpec(ComputePattern.TP1D)) + output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) + + compute_spec = weight.compute_spec + + if compute_spec.output_replicate: + return output.to_replicate() + else: + return output + + +def colo_embedding_1Drow(input_tensor: ColoTensor, + weight: ColoTensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False) -> ColoTensor: + # embedding_1Drow splits the weight(lookup table) to the shape, [num_embeddings/P, embedding_dim] + # get the index of current segment and mask other segments with 0 + + # get complete input tensor through all-gather + input_tensor = input_tensor.redistribute(ReplicaSpec()) + + # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + tensor_parallel_rank = weight.get_process_group().tp_local_rank() + num_embeddings_per_partition = weight.size_local(0) + vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition + vocab_end_index = vocab_start_index + num_embeddings_per_partition + + # build the mask. + input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index) + # mask the input. + # TODO(jzy) masked_input may be an activation managed by ColoTensor. + masked_input = input_tensor - vocab_start_index + masked_input[input_mask] = 0 + + partial_output = F.embedding(masked_input, + weight, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + + # Mask the output embedding. + partial_output[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. + output = reduce_input(partial_output, weight.get_process_group()) + output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec())) + return output + + +def colo_embedding_1d(mode: str, + input_tensor: ColoTensor, + weight: ColoTensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False) -> ColoTensor: + assert mode in ('row', 'col') + funcs = {'row': colo_embedding_1Drow, 'col': colo_embedding_1Dcol} + return funcs[mode](input_tensor, + weight, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + + +@colo_op_impl(F.embedding) +def colo_embedding(input_tensor: GeneralTensor, + weight: GeneralTensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False): + """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``. + This method looks up an embedding table. + """ + assert isinstance(weight, ColoTensor) + input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) + + if not weight.has_compute_spec(): # No Model Parallel Applied + assert weight.is_replicate(), 'Invalid weight spec for native embedding op' + return ColoTensor.from_torch_tensor(tensor=F.embedding(input_tensor, + weight, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse), + spec=ColoTensorSpec(weight.get_process_group())) + elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied + if weight.is_shard_1drow(): + mode = 'row' + elif weight.is_shard_1dcol(): + mode = 'col' + else: + raise NotImplementedError + return colo_embedding_1d(mode, + input_tensor, + weight, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + else: + raise NotImplementedError diff --git a/colossalai/nn/_ops/embedding_bag.py b/colossalai/nn/_ops/embedding_bag.py new file mode 100644 index 0000000000000000000000000000000000000000..0e8aa8fecb0197fc966cf2ec0c611cafcdb38368 --- /dev/null +++ b/colossalai/nn/_ops/embedding_bag.py @@ -0,0 +1,125 @@ +import torch.nn.functional as F +from typing import Optional +from torch import Tensor +from colossalai.tensor.op_wrapper import colo_op_impl +from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, \ + ShardSpec, ReplicaSpec +from ._utils import GeneralTensor, convert_to_colo_tensor + + +def colo_embedding_bag_1Dcol(input_tensor: ColoTensor, + weight: ColoTensor, + offsets: Optional[Tensor] = None, + max_norm: Optional[float] = None, + norm_type: float = 2, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + per_sample_weights: Optional[Tensor] = None, + include_last_offset: bool = False, + padding_idx: Optional[int] = None) -> ColoTensor: + # embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P) + # Gather splitted lookup table + pg = weight.get_process_group() + input_tensor = input_tensor.redistribute(ReplicaSpec()) + + output_parallel = F.embedding_bag(input_tensor, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + padding_idx=padding_idx) + output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) + + if weight.compute_spec.output_replicate: + return output.to_replicate() + else: + return output + + +def colo_embedding_bag_1d(tp_mode: str, + input_tensor: ColoTensor, + weight: ColoTensor, + offsets: Optional[Tensor] = None, + max_norm: Optional[float] = None, + norm_type: float = 2, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + per_sample_weights: Optional[Tensor] = None, + include_last_offset: bool = False, + padding_idx: Optional[int] = None) -> ColoTensor: + assert tp_mode in ('col',) + funcs = {'col': colo_embedding_bag_1Dcol} + return funcs[tp_mode](input_tensor, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + padding_idx=padding_idx) + + +@colo_op_impl(F.embedding_bag) +def colo_embedding_bag(input_tensor: GeneralTensor, + weight: GeneralTensor, + offsets: Optional[Tensor] = None, + max_norm: Optional[float] = None, + norm_type: float = 2, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + per_sample_weights: Optional[Tensor] = None, + include_last_offset: bool = False, + padding_idx: Optional[int] = None): + """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``. + This method looks up an embedding table. + """ + assert isinstance(weight, ColoTensor) + input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) + + # Handle differen parallel actions. + + if not weight.has_compute_spec(): # No Model Parallel Applied + assert weight.is_replicate(), 'Invalid weight spec for native embedding op' + return ColoTensor.from_torch_tensor(tensor=F.embedding_bag(input_tensor, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + padding_idx=padding_idx), + spec=ColoTensorSpec(weight.get_process_group())) + elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied + if weight.is_shard_1dcol(): + tp_mode = 'col' + else: + raise NotImplementedError + return colo_embedding_bag_1d(tp_mode, + input_tensor, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + padding_idx=padding_idx) + else: + raise NotImplementedError diff --git a/colossalai/nn/_ops/layernorm.py b/colossalai/nn/_ops/layernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..2b761b84e3ee8aa9dcaf7ef1ba054b6857aa4b49 --- /dev/null +++ b/colossalai/nn/_ops/layernorm.py @@ -0,0 +1,25 @@ +from typing import List, Optional +import torch.nn.functional as F +from colossalai.tensor.op_wrapper import colo_op_impl +from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec, ReplicaSpec +from ._utils import GeneralTensor, convert_to_colo_tensor + + +@colo_op_impl(F.layer_norm) +def colo_layernorm( + input_tensor: GeneralTensor, + normalized_shape: List[int], + weight: Optional[GeneralTensor] = None, + bias: Optional[GeneralTensor] = None, + eps: float = 1e-5, +): + assert isinstance(weight, ColoTensor) + input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) + bias = convert_to_colo_tensor(bias, weight.get_process_group()) + input_tensor = input_tensor.redistribute(ReplicaSpec()) + + output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps) + output = ColoTensor.from_torch_tensor(tensor=output, + spec=ColoTensorSpec(pg=input_tensor.get_process_group(), + dist_attr=input_tensor.dist_spec)) + return output diff --git a/colossalai/nn/_ops/linear.py b/colossalai/nn/_ops/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..8835574de5bb9f2ec25f26a74760a792003fde30 --- /dev/null +++ b/colossalai/nn/_ops/linear.py @@ -0,0 +1,171 @@ +import torch.nn.functional as F +from typing import Optional +from ._utils import GeneralTensor, convert_to_colo_tensor +from colossalai.tensor.op_wrapper import colo_op_impl +from ._utils import reduce_input, reduce_grad +from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec +from colossalai.tensor.sharding_spec import ShardingSpec +from copy import deepcopy + + +def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': + # Input:S[1] x Weight:S[0] = Output:P + # All-Reduce(Output) + bias = res + # Input:S[1] + pg = weight.get_process_group() + input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]), pg) + + # Output:P + partial_output = F.linear(input_tensor, weight) + # Reduce(Output) + + output = reduce_input(partial_output, pg) + # Bias + if bias is not None: + assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op' + output = output + bias + + output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec())) + return output + + +def colo_linear_1dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': + # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1] + # All-Gather(Output) + # Input:B + compute_spec = weight.compute_spec + input_tensor = input_tensor.redistribute(ReplicaSpec()) + input_parallel = reduce_grad(input_tensor, weight.get_process_group()) + + output_parallel = F.linear(input_parallel, weight, bias) + output = ColoTensor.from_torch_tensor(output_parallel, + spec=ColoTensorSpec(weight.get_process_group(), + ShardSpec([-1], [weight.get_tp_world_size()]), + ComputeSpec(ComputePattern.TP1D))) + if compute_spec.output_replicate: + return output.to_replicate() + else: + return output + + +def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': + assert mode in ('row', 'col') + funcs = {'row': colo_linear_1drow, 'col': colo_linear_1dcol} + return funcs[mode](input_tensor, weight, bias) + + +# @register_colo_graph(input_pos=[1], param_pos=[2, 3]) +def colo_linear_imp(input_tensor: GeneralTensor, + weight: GeneralTensor, + bias: Optional[GeneralTensor] = None) -> 'ColoTensor': + """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. + This method computes a linear. + """ + assert isinstance(weight, ColoTensor) + pg = weight.get_process_group() + assert pg + input_tensor = convert_to_colo_tensor(input_tensor, pg) + bias = convert_to_colo_tensor(bias, pg) + # input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias))) + + # Add communication logic before and after linear call. + ret_tensor = None + if not weight.has_compute_spec(): # No Model Parallel Applied + assert weight.is_replicate(), 'Invalid weight spec for native Linear op' + assert bias is None or bias.is_replicate(), 'Invalid bias spec for native Linear op' + ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias), spec=ColoTensorSpec(pg)) + elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied + if weight.is_shard_1dcol() and (bias is None or bias.is_replicate()): + mode = 'row' + elif weight.is_shard_1drow() and (bias is None or bias.is_shard_1drow() or bias.is_shard_1dcol()): + mode = 'col' + else: + raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight}, bias {bias}") + ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias) + else: + raise NotImplementedError + + return ret_tensor + + +def _new_colo_linear_imp(input_tensor: GeneralTensor, + weight: GeneralTensor, + bias: Optional[GeneralTensor] = None) -> 'ColoTensor': + """ + A tentative function to compute the distributed linear layer with the latest sharding spec. + This function is subject to future change as the current sharding API is not stable. + """ + # get mesh info + input_sharding_seq = input_tensor.sharding_spec.sharding_sequence + weight_sharding_seq = weight.sharding_spec.sharding_sequence + if bias is not None: + bias_sharding_seq = bias.sharding_spec.sharding_sequence + device_mesh = weight.sharding_spec.device_mesh + pg_axis0 = weight.pg_axis0 + pg_axis1 = weight.pg_axis1 + + # the last dim of input should have the same spec as the first dim of weight + # the weight is transposed, so we look at the second dimension + assert input_sharding_seq[-1] == weight_sharding_seq[1] + + if bias is not None: + assert bias_sharding_seq[0] == weight_sharding_seq[0] + + # compute the output sharding sequence + # as weight is transposed, so we look at the first dimension + output_shard_seq = input_sharding_seq[:-1] + weight_sharding_seq[:1] + output_shard_seq = deepcopy(output_shard_seq) + + # TODO: add reduce grad logic + + # handle column and row parallel linear + # by reusing the implementation above + out = F.linear(input_tensor, weight) + + # run all reduce if necessary + last_dim_spec = input_sharding_seq[-1] + if last_dim_spec.is_replica: + pass + elif last_dim_spec.shard_list is not None: + for dim in last_dim_spec.shard_list: + if dim == 0: + reduce_input(out, pg_axis0) + elif dim == 1: + reduce_input(out, pg_axis1) + else: + raise RuntimeError("Found invalid sharding axis {dim}, only 0 or 1 is expected") + # add bias + if bias is not None: + out += bias + + # convert shard seq to partition dict + output_partition_dict = {} + for index, dim_spec in enumerate(output_shard_seq): + if not dim_spec.is_replica: + if index not in output_partition_dict: + output_partition_dict[index] = [] + output_partition_dict[index].extend(dim_spec.shard_list) + + entire_shape = out.shape + output_sharding_spec = ShardingSpec(device_mesh, entire_shape, output_partition_dict) + ret_tensor = ColoTensor.from_torch_tensor(out) + setattr(ret_tensor, 'sharding_spec', output_sharding_spec) + return ret_tensor + + +def _has_sharding_spec(tensor): + """ + A tentative function to check whether the tensor is using the new sharding spec API. We assume that the sharding spec object is + set as the attribute `sharding_spec` on a tensor. + """ + return hasattr(tensor, 'sharding_spec') + + +@colo_op_impl(F.linear) +def colo_linear(input_tensor: GeneralTensor, + weight: GeneralTensor, + bias: Optional[GeneralTensor] = None) -> 'ColoTensor': + if _has_sharding_spec(weight): + return _new_colo_linear_imp(input_tensor, weight, bias) + else: + return colo_linear_imp(input_tensor, weight, bias) diff --git a/colossalai/nn/_ops/loss.py b/colossalai/nn/_ops/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1e54f662859ceffa3edb66157c187c6e17658142 --- /dev/null +++ b/colossalai/nn/_ops/loss.py @@ -0,0 +1,48 @@ +import torch +import torch.nn.functional as F +from typing import Optional +from colossalai.tensor.op_wrapper import colo_op_impl +from colossalai.tensor import ColoTensor, ColoTensorSpec +from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D +from ._utils import GeneralTensor, convert_to_colo_tensor + + +@colo_op_impl(F.cross_entropy) +def colo_cross_entropy(input_tensor: GeneralTensor, + target: GeneralTensor, + weight: Optional[GeneralTensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", + label_smoothing: float = 0.0): + assert isinstance(weight, ColoTensor) or isinstance(target, ColoTensor) or isinstance(input_tensor, ColoTensor) + pg = input_tensor.get_process_group() if isinstance(input_tensor, ColoTensor) else isinstance(target, ColoTensor) + weight = convert_to_colo_tensor(weight, pg) + target = convert_to_colo_tensor(target, pg) + input_tensor = convert_to_colo_tensor(input_tensor, pg) + + if input_tensor.is_replicate(): # Input is gathered + assert target.is_replicate() and (weight is None or weight.is_replicate()), \ + "Target tensor and weight tensor both should be complete" + output = F.cross_entropy(input_tensor, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + label_smoothing=label_smoothing) + return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)) + elif input_tensor.has_compute_spec(): # Single Model Parallel Applied + if input_tensor.is_shard_1dcol(): + assert weight is None, "Current TP cross entropy loss function doesn't support passing weight tensor in" + assert target.is_replicate(), "Target tensor should be complete in TP cross entropy loss function" + output = VocabParallelCrossEntropyLoss1D()(input_tensor, + target, + process_group=input_tensor.process_group.tp_process_group()) + return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)) + else: + raise NotImplementedError + else: + raise NotImplementedError diff --git a/colossalai/nn/_ops/view.py b/colossalai/nn/_ops/view.py new file mode 100644 index 0000000000000000000000000000000000000000..fafd23e900ad80856f4e585e7292618da711af1d --- /dev/null +++ b/colossalai/nn/_ops/view.py @@ -0,0 +1,97 @@ +import math +import torch +from colossalai.tensor.op_wrapper import colo_op_impl +from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec +from typing import Optional, Union + + +def _all_int(my_iter): + return all(isinstance(i, int) for i in my_iter) + + +def _get_valid_shape(shape): + if isinstance(shape, list): + if _all_int(shape): + return tuple(shape) + else: + raise RuntimeError("expects type(int) but finds an other type") + elif isinstance(shape, tuple): + if _all_int(shape): + return shape + else: + return _get_valid_shape(shape[0]) + else: + raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape))) + + +def _shape_infer(org_sp, tgt_sp): + cnt = 0 + pos = 0 + for idx, dim in enumerate(tgt_sp): + if dim < -1: + raise RuntimeError("invalid shape dimension {}".format(dim)) + elif dim == -1: + cnt += 1 + pos = idx + + if cnt > 1: + raise RuntimeError("only one dimension can be inferred") + + org_prod = math.prod(org_sp) + tgt_prod = math.prod(tgt_sp) + + if cnt == 0: + if org_prod != tgt_prod: + raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod)) + else: + return tgt_sp + elif org_prod % tgt_prod != 0: + raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod)) + + infer_dim = -(org_prod // tgt_prod) + return tgt_sp[: pos] + (infer_dim,) + tgt_sp[pos + 1:] + + +@colo_op_impl(torch.Tensor.view) +def colo_view(self: ColoTensor, *shape) -> 'ColoTensor': + """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``. + Changes the shape of the current tensor. + """ + assert isinstance(self, ColoTensor) + # apply original `view` function for replicated colo tensors + if self.is_replicate(): + return self.view(*shape) + + cur_sp = self.size() + org_sp = self.size_global() + # parse the passed arguments + tgt_sp = _get_valid_shape(shape) + # get the correct shape from inference + inf_sp = _shape_infer(org_sp, tgt_sp) + + if self.is_shard_1drow() and org_sp[0] == inf_sp[0]: + new_shape = (cur_sp[0],) + tgt_sp[1:] + res = self.view(*new_shape) + elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]: + new_shape = tgt_sp[:-1] + (cur_sp[-1],) + res = self.view(*new_shape) + else: + replicated_t = self.redistribute(dist_spec=ReplicaSpec()) + return ColoTensor.from_torch_tensor( + tensor=replicated_t.view(*shape), + spec=ColoTensorSpec(self.get_process_group())) + + return ColoTensor.from_torch_tensor( + tensor=res, + spec=ColoTensorSpec( + pg=self.get_process_group(), + dist_attr=self.dist_spec)) + + +@colo_op_impl(torch.Tensor.size) +def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]: + size = self.size_global() + if dim is None: + return size + else: + return size[dim] diff --git a/colossalai/nn/init.py b/colossalai/nn/init.py new file mode 100644 index 0000000000000000000000000000000000000000..559b7038fc352a0ccbea22403cd4b1284bed42e0 --- /dev/null +++ b/colossalai/nn/init.py @@ -0,0 +1,252 @@ +import math +import warnings + +from torch import Tensor +import torch.nn as nn + + +def zeros_(): + """Return the initializer filling the input Tensor with the scalar zeros""" + + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + return nn.init.zeros_(tensor) + + return initializer + + +def ones_(): + """Return the initializer filling the input Tensor with the scalar ones""" + + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + return nn.init.ones_(tensor) + + return initializer + + +def uniform_(a: float = 0., b: float = 1.): + r"""Return the initializer filling the input Tensor with values drawn from the uniform + distribution :math:`\mathcal{U}(a, b)`. + + Args: + a (float): the lower bound of the uniform distribution. Defaults 0.0. + b (float): the upper bound of the uniform distribution. Defaults 1.0. + """ + + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + return nn.init.uniform_(tensor, a, b) + + return initializer + + +def normal_(mean: float = 0., std: float = 1.): + r"""Return the initializer filling the input Tensor with values drawn from the normal distribution + + .. math:: + \mathcal{N}(\text{mean}, \text{std}^2) + + Args: + mean (float): the mean of the normal distribution. Defaults 0.0. + std (float): the standard deviation of the normal distribution. Defaults 1.0. + """ + + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + return nn.init.normal_(tensor, mean, std) + + return initializer + + +def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = 2.): + r"""Return the initializer filling the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + Args: + mean (float): the mean of the normal distribution. Defaults 0.0. + std (float): the standard deviation of the normal distribution. Defaults 1.0. + a (float): the minimum cutoff value. Defaults -2.0. + b (float): the maximum cutoff value. Defaults 2.0. + """ + + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + return nn.init.trunc_normal_(tensor, mean, std, a, b) + + return initializer + + +def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'): + r"""Return the initializer filling the input `Tensor` with values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification` - He, K. et al. (2015), using a + uniform distribution. The resulting tensor will have values sampled from + :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan_mode}}} + + Also known as 'He initialization'. + + Args: + a (int): the negative slope of the rectifier used after this layer (only used with ``'leaky_relu'``). + mode (str, optional): either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity (str, optional): the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + """ + + # adapted from torch.nn.init + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + + if mode == 'fan_in': + assert fan_in is not None, 'Fan_in is not provided.' + fan = fan_in + elif mode == 'fan_out': + assert fan_out is not None, 'Fan_out is not provided.' + fan = fan_out + else: + raise ValueError(f'Invalid initialization mode \'{mode}\'') + + std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan) + bound = math.sqrt(3.) * std + return nn.init.uniform_(tensor, -bound, bound) + + return initializer + + +def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'): + r"""Return the initializer filling the input `Tensor` with values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification` - He, K. et al. (2015), using a + normal distribution. The resulting tensor will have values sampled from + :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \frac{\text{gain}}{\sqrt{\text{fan_mode}}} + + Also known as 'He initialization'. + + Args: + a (int): the negative slope of the rectifier used after this layer (only used with ``'leaky_relu'``). + mode (str, optional): either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity (str, optional): the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + """ + + # adapted from torch.nn.init + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + + if mode == 'fan_in': + assert fan_in is not None, 'Fan_in is not provided.' + fan = fan_in + elif mode == 'fan_out': + assert fan_out is not None, 'Fan_out is not provided.' + fan = fan_out + else: + raise ValueError(f'Invalid initialization mode \'{mode}\'') + + std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan) + return nn.init.normal_(tensor, 0, std) + + return initializer + + +def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1.): + r"""Return the initializer filling the input `Tensor` with values according to the method + described in `Understanding the difficulty of training deep feedforward + neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform + distribution. The resulting tensor will have values sampled from + :math:`\mathcal{U}(-a, a)` where + + .. math:: + a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}} + + Also known as 'Glorot initialization'. + + Args: + a (float, optional): an optional scaling factor used to calculate uniform + bounds from standard deviation. Defaults ``math.sqrt(3.)``. + scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0. + gain (float, optional): an optional scaling factor. Defaults 1.0. + """ + + # adapted from torch.nn.init + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + assert fan_in is not None, 'Fan_in is not provided.' + + fan = fan_in + if fan_out is not None: + fan += fan_out + + std = gain * math.sqrt(scale / float(fan)) + bound = a * std + return nn.init.uniform_(tensor, -bound, bound) + + return initializer + + +def xavier_normal_(scale: float = 2., gain: float = 1.): + r"""Return the initializer filling the input `Tensor` with values according to the method + described in `Understanding the difficulty of training deep feedforward + neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal + distribution. The resulting tensor will have values sampled from + :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}} + + Also known as 'Glorot initialization'. + + Args: + scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0. + gain (float, optional): an optional scaling factor. Defaults 1.0. + """ + + # adapted from torch.nn.init + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + assert fan_in is not None, 'Fan_in is not provided.' + + fan = fan_in + if fan_out is not None: + fan += fan_out + + std = gain * math.sqrt(scale / float(fan)) + + return nn.init.normal_(tensor, 0., std) + + return initializer + + +def lecun_uniform_(): + # adapted from jax.nn.initializers + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + assert fan_in is not None, 'Fan_in is not provided.' + + var = 1.0 / fan_in + bound = math.sqrt(3 * var) + return nn.init.uniform_(tensor, -bound, bound) + + return initializer + + +def lecun_normal_(): + # adapted from jax.nn.initializers + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + assert fan_in is not None, 'Fan_in is not provided.' + + std = math.sqrt(1.0 / fan_in) + return nn.init.trunc_normal_(tensor, std=std / .87962566103423978) + + return initializer diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b705632f80407ceda019f4b905c724d1b2254f66 --- /dev/null +++ b/colossalai/nn/layer/__init__.py @@ -0,0 +1,10 @@ +from .colossalai_layer import * +from .parallel_1d import * +from .parallel_2d import * +from .parallel_2p5d import * +from .parallel_3d import * +from .parallel_sequence import * +from .moe import * +from .utils import * +from .vanilla import * +from .wrapper import * diff --git a/colossalai/nn/layer/base_layer.py b/colossalai/nn/layer/base_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..c85f53cc44c3660e39a1cf2995ca3e44a70b4c04 --- /dev/null +++ b/colossalai/nn/layer/base_layer.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.nn as nn + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from contextlib import contextmanager + + +class ParallelLayer(nn.Module): + global_state_dict: bool = True + + def __init__(self): + super().__init__() + self.data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank( + ParallelMode.DATA) + self.data_parallel_size = 1 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_world_size( + ParallelMode.DATA) + + self.tensor_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_local_rank( + ParallelMode.TENSOR) + self.tensor_parallel_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size( + ParallelMode.TENSOR) + + self.pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + + def _load_from_global_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + return super()._save_to_state_dict(destination, prefix, keep_vars) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + if self.global_state_dict: + if gpc.get_local_rank(ParallelMode.TENSOR) != 0: + missing_keys.clear() + unexpected_keys.clear() + return self._load_from_global_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, + unexpected_keys, error_msgs) + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if self.global_state_dict: + return self._save_to_global_state_dict(destination, prefix, keep_vars) + return super()._save_to_state_dict(destination, prefix, keep_vars) + + @classmethod + @contextmanager + def use_local_state_dict(cls): + try: + cls.global_state_dict = False + yield + finally: + cls.global_state_dict = True diff --git a/colossalai/nn/layer/colossalai_layer/__init__.py b/colossalai/nn/layer/colossalai_layer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed743820ddbc87e18bd801fc93025e3f69d47797 --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/__init__.py @@ -0,0 +1,7 @@ +from ._utils import partition_batch +from .dropout import Dropout +from .embedding import Embedding, PatchEmbedding +from .linear import Classifier, Linear +from .normalization import LayerNorm + +__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch'] diff --git a/colossalai/nn/layer/colossalai_layer/_utils.py b/colossalai/nn/layer/colossalai_layer/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..878fe9a68644c392a5fed11f2a76900a75629634 --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/_utils.py @@ -0,0 +1,38 @@ +import torch.nn as nn +from torch import Tensor + +from ..parallel_2d._operation import split_batch_2d +from ..parallel_2p5d._operation import split_batch_2p5d +from ..parallel_3d._operation import split_batch_3d +from ..utils import get_tensor_parallel_mode + +_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d} + + +def partition_batch(input_) -> Tensor: + tensor_parallel_mode = get_tensor_parallel_mode() + if tensor_parallel_mode in _parallel_split_batch: + if isinstance(input_, dict): + return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()} + else: + return _parallel_split_batch[tensor_parallel_mode](input_) + else: + return input_ + + +class ColossalaiModule(nn.Module): + + def __init__(self, module: nn.Module, **kwargs): + super().__init__() + # copy values + self.__dict__ = module.__dict__.copy() + # copy methods + for name, attr in module.__class__.__dict__.items(): + if name not in ['__init__', 'forward'] and callable(attr): + setattr(self, name, getattr(module, name)) + self._forward_func = module.forward + for k, v in kwargs.items(): + setattr(self, k, v) + + def forward(self, *args): + return self._forward_func(*args) diff --git a/colossalai/nn/layer/colossalai_layer/dropout.py b/colossalai/nn/layer/colossalai_layer/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..30d3f295372324995f8a53de89bacd025c8c49be --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/dropout.py @@ -0,0 +1,30 @@ +import torch.nn as nn +from colossalai.context import ParallelMode, seed + +from ..parallel_1d import * +from ..utils import get_tensor_parallel_mode +from ._utils import ColossalaiModule + + +class Dropout(ColossalaiModule): + """Dropout layer of colossalai. + + Args: + p (float, optional): probability of an element to be zeroed, defaults 0.5. + inplace (bool, optional): whether to do dropout in-place, default to be False. + """ + + def __init__(self, p: float = 0.5, inplace: bool = False) -> None: + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel == "1d": + drop = Dropout1D(p, inplace) + else: + drop = nn.Dropout(p, inplace) + super().__init__(drop, tensor_parallel=tensor_parallel) + + def forward(self, *args): + if self.tensor_parallel in [None, '1d']: + return self._forward_func(*args) + else: + with seed(ParallelMode.TENSOR): + return self._forward_func(*args) diff --git a/colossalai/nn/layer/colossalai_layer/embedding.py b/colossalai/nn/layer/colossalai_layer/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..0b41f8833117ced6c7058e735e824b38006db752 --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/embedding.py @@ -0,0 +1,151 @@ +import math +from typing import Callable + +from colossalai.utils import get_current_device +from torch import dtype, nn + +from ... import init as init +from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D +from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D +from ..parallel_2p5d import Embedding2p5D, PatchEmbedding2p5D, VocabParallelEmbedding2p5D +from ..parallel_3d import Embedding3D, PatchEmbedding3D, VocabParallelEmbedding3D +from ..utils import get_tensor_parallel_mode +from ..vanilla import VanillaPatchEmbedding +from ._utils import ColossalaiModule + +_parallel_embedding = { + '1d': Embedding1D, + '2d': Embedding2D, + '2.5d': Embedding2p5D, + '3d': Embedding3D, +} + +_vocab_parallel_embedding = { + '1d': VocabParallelEmbedding1D, + '2d': VocabParallelEmbedding2D, + '2.5d': VocabParallelEmbedding2p5D, + '3d': VocabParallelEmbedding3D +} + +_parallel_patchembedding = { + None: VanillaPatchEmbedding, + '1d': PatchEmbedding1D, + '2d': PatchEmbedding2D, + '2.5d': PatchEmbedding2p5D, + '3d': PatchEmbedding3D +} + + +class Embedding(ColossalaiModule): + r"""Embedding for colossalai. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed โ€œpadโ€, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: dtype = None, + weight_initializer: Callable = init.normal_(), + vocab_parallel_limit: int = 2048, + *args, + **kwargs) -> None: + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel is None: + embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, + **kwargs).to(dtype).to(get_current_device()) + weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) + elif num_embeddings <= vocab_parallel_limit: + embed = _parallel_embedding[tensor_parallel]( + num_embeddings, + embedding_dim, + padding_idx=padding_idx, + dtype=dtype, + weight_initializer=weight_initializer, + *args, + **kwargs, + ) + else: + embed = _vocab_parallel_embedding[tensor_parallel]( + num_embeddings, + embedding_dim, + padding_idx=padding_idx, + dtype=dtype, + weight_initializer=weight_initializer, + *args, + **kwargs, + ) + super().__init__(embed) + + +class PatchEmbedding(ColossalaiModule): + """2D Image to Patch Embedding. + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + dtype: dtype = None, + flatten: bool = True, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_() + ) -> None: + tensor_parallel = get_tensor_parallel_mode() + embed = _parallel_patchembedding[tensor_parallel]( + img_size, + patch_size, + in_chans, + embed_size, + dtype=dtype, + flatten=flatten, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + position_embed_initializer=position_embed_initializer, + ) + super().__init__(embed) diff --git a/colossalai/nn/layer/colossalai_layer/linear.py b/colossalai/nn/layer/colossalai_layer/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..3e0c6e285c1c64ccfc25bf5eca6f636cc8744aea --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/linear.py @@ -0,0 +1,141 @@ +import inspect +import math +from typing import Callable + +from torch import dtype, nn + +from colossalai.utils import get_current_device + +from ... import init as init +from ..parallel_1d import * +from ..parallel_2d import * +from ..parallel_2p5d import * +from ..parallel_3d import * +from ..utils import get_tensor_parallel_mode +from ..vanilla import * +from ._utils import ColossalaiModule + +_parallel_linear = {None: VanillaLinear, '1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D} + +_parallel_classifier = { + None: VanillaClassifier, + '1d': Classifier1D, + '2d': Classifier2D, + '2.5d': Classifier2p5D, + '3d': Classifier3D +} + +_vocab_parallel_classifier = { + '1d': VocabParallelClassifier1D, + '2d': VocabParallelClassifier2D, + '2.5d': VocabParallelClassifier2p5D, + '3d': VocabParallelClassifier3D +} + + +class Linear(ColossalaiModule): + """Linear layer of colossalai. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + Note: ``kwargs`` would contain different parameters when you use different parallelisms. + + The ``kwargs`` should contain parameters below: + :: + + Linear1D: + gather_output: bool (optional, default to be false) + skip_bias_add: bool (optional, default to be false) + Linear2D: + skip_bias_add: bool (optional, default to be false) + Linear2p5D: + skip_bias_add: bool (optional, default to be false) + Linear3D: + None + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs) -> None: + tensor_parallel = get_tensor_parallel_mode() + linear_cls = _parallel_linear[tensor_parallel] + gather_output = kwargs.pop('gather_output', None) + if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available + kwargs['gather_output'] = gather_output + layer = linear_cls( + in_features, + out_features, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + **kwargs, + ) + super().__init__(layer) + + +class Classifier(ColossalaiModule): + """Classifier layer of colossalai. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: nn.Parameter = None, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + vocab_parallel_limit: int = 2048) -> None: + tensor_parallel = get_tensor_parallel_mode() + if num_classes <= vocab_parallel_limit or tensor_parallel is None: + layer = _parallel_classifier[tensor_parallel]( + in_features, + num_classes, + weight=weight, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) + else: + layer = _vocab_parallel_classifier[tensor_parallel]( + in_features, + num_classes, + weight=weight, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) + super().__init__(layer) diff --git a/colossalai/nn/layer/colossalai_layer/normalization.py b/colossalai/nn/layer/colossalai_layer/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..8211d76ad7f1de6b1aee01335c3f91b239386136 --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/normalization.py @@ -0,0 +1,41 @@ +from colossalai.utils import get_current_device +from torch import nn + +from ..parallel_1d import LayerNorm1D +from ..parallel_2d import LayerNorm2D +from ..parallel_2p5d import LayerNorm2p5D +from ..parallel_3d import LayerNorm3D +from ..utils import get_tensor_parallel_mode +from ..vanilla import VanillaLayerNorm +from ._utils import ColossalaiModule + +_parallel_layernorm = { + None: VanillaLayerNorm, + "1d": LayerNorm1D, + "2d": LayerNorm2D, + "2.5d": LayerNorm2p5D, + "3d": LayerNorm3D, +} + + +class LayerNorm(ColossalaiModule): + r"""Layer Normalization for colossalai. + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None: + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel is None: + norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device()) + else: + norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) + super().__init__(norm) diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4f9dcd0b3137b096e903ae50222f2c2d7f1f819 --- /dev/null +++ b/colossalai/nn/layer/moe/__init__.py @@ -0,0 +1,9 @@ +from .experts import Experts, FFNExperts, TPExperts +from .layers import MoeLayer, MoeModule +from .routers import MoeRouter, Top1Router, Top2Router +from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts + +__all__ = [ + 'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator', + 'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter' +] diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..278cdfbb771253bf2c535cf9576ddf3634c90871 --- /dev/null +++ b/colossalai/nn/layer/moe/_operation.py @@ -0,0 +1,154 @@ +from typing import Any, Optional, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup + +COL_MOE_KERNEL_FLAG = False +try: + import colossalai._C.moe + + COL_MOE_KERNEL_FLAG = True +except ImportError: + print("If you want to activate cuda mode for MoE, please install with cuda_ext!") + + +class AllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + if ctx is not None: + ctx.comm_grp = group + + comm_size = dist.get_world_size(group) + if comm_size == 1: + return inputs.unsqueeze(0) + + buffer_shape = (comm_size,) + inputs.shape + outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device) + buffer_list = list(torch.chunk(outputs, comm_size, dim=0)) + dist.all_gather(buffer_list, inputs, group=group) + return outputs + + @staticmethod + def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: + return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None + + +class ReduceScatter(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + if ctx is not None: + ctx.comm_grp = group + + comm_size = dist.get_world_size(group) + if comm_size == 1: + return inputs.squeeze(0) + + if not inputs.is_contiguous(): + inputs = inputs.contiguous() + + output_shape = inputs.shape[1:] + outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device) + buffer_list = list(torch.chunk(inputs, comm_size, dim=0)) + dist.reduce_scatter(outputs, buffer_list, group=group) + return outputs + + @staticmethod + def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: + return AllGather.forward(None, grad_outputs, ctx.comm_grp), None + + +class AllToAll(torch.autograd.Function): + """Dispatches input tensor [e, c, h] to all experts by all_to_all_single + operation in torch.distributed. + """ + + @staticmethod + def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + if ctx is not None: + ctx.comm_grp = group + if not inputs.is_contiguous(): + inputs = inputs.contiguous() + if dist.get_world_size(group) == 1: + return inputs + output = torch.empty_like(inputs) + dist.all_to_all_single(output, inputs, group=group) + return output + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: + return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None + + +class MoeDispatch(torch.autograd.Function): + + @staticmethod + def forward(ctx, tokens, mask, dest_idx, ec): + s = tokens.size(0) + h = tokens.size(1) + + expert_input = colossalai._C.moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx) + + ctx.save_for_backward(mask, dest_idx) + ctx.s = s + ctx.h = h + ctx.ec = ec + + return expert_input + + @staticmethod + def backward(ctx, output_grad): + mask, dest_idx = ctx.saved_tensors + d_tokens = colossalai._C.moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx) + return d_tokens, None, None, None + + +class MoeCombine(torch.autograd.Function): + + @staticmethod + def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): + assert logits.dtype == torch.float32 + + s = logits.size(0) + e = logits.size(1) + c = ec // e + h = expert_tokens.size(-1) + + fp16_flag = (expert_tokens.dtype == torch.float16) + cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens + ctokens = colossalai._C.moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) + output = ctokens.to(torch.float16) if fp16_flag else ctokens + + ctx.save_for_backward(expert_tokens, logits, mask, dest_idx) + ctx.s = s + ctx.e = e + ctx.c = c + ctx.h = h + ctx.fp16_flag = fp16_flag + + return output + + @staticmethod + def backward(ctx, tokens_grad): + expert_tokens, logits, mask, dest_idx = ctx.saved_tensors + + cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \ + else tokens_grad + cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens + d_expert, d_logits = colossalai._C.moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, + mask, dest_idx) + d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert + + return d_expert, d_logits, None, None, None + + +def moe_cumsum(inputs: Tensor): + dim0 = inputs.size(0) + flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0) + if flag and COL_MOE_KERNEL_FLAG: + return colossalai._C.moe.cumsum_sub_one(inputs) + else: + return torch.cumsum(inputs, dim=0) - 1 diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py new file mode 100644 index 0000000000000000000000000000000000000000..a8c51514311e66331656dc3583f2d7b1a0cf7d48 --- /dev/null +++ b/colossalai/nn/layer/moe/experts.py @@ -0,0 +1,172 @@ +import math + +import torch +import torch.nn as nn +from colossalai.context import ParallelMode, seed +from colossalai.utils import get_current_device +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.zero.init_ctx import no_shard_zero_decrator +from typing import Type + + +class MoeExperts(nn.Module): + """Basic class for experts in MoE. It stores what kind of communication expersts use + to exchange tokens, how many experts in a single GPU and parallel information such as + expert parallel size, data parallel size and their distributed communication groups. + """ + + def __init__(self, comm_name: str, num_experts: int): + super().__init__() + assert comm_name in {"all_to_all", "all_gather"}, \ + "This kind of communication has not been implemented yet.\n Please use Experts build function." + self.comm_name = comm_name + # Get the configuration of experts' deployment and parallel information from moe contex + self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts) + + +@no_shard_zero_decrator(is_replicated=False) +class Experts(MoeExperts): + """A wrapper class to create experts. It will create E experts across the + moe model parallel group, where E is the number of experts. Every expert + is a instence of the class, 'expert' in initialization parameters. + + Args: + expert_cls (:class:`torch.nn.Module`): The class of all experts + num_experts (int): The number of experts + expert_args: Args used to initialize experts, the args could be found in corresponding expert class + """ + + def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args): + super().__init__("all_to_all", num_experts) + + # Use seed to make every expert different from others + with seed(ParallelMode.TENSOR): + self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)]) + + # Attach parallel information for all parameters in Experts + for exp in self.experts: + for param in exp.parameters(): + param.__setattr__('moe_info', self.dist_info) + + def forward(self, inputs: torch.Tensor): + # Split inputs for each expert + expert_input = torch.chunk(inputs, self.num_local_experts, dim=1) + expert_output = [] + + # Get outputs from each expert + for i in range(self.num_local_experts): + expert_output.append(self.experts[i](expert_input[i])) + + # Concatenate all outputs together + output = torch.cat(expert_output, dim=1).contiguous() + return output + + +class FFNExperts(MoeExperts): + """Use torch.bmm to speed up for multiple experts. + """ + + def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + super().__init__("all_to_all", num_experts) + + self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device())) + self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device())) + + self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device())) + self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device())) + + s1 = math.sqrt(0.1 / d_model) + s2 = math.sqrt(0.1 / d_ff) + + with seed(ParallelMode.TENSOR): + nn.init.trunc_normal_(self.w1, std=s1) + nn.init.trunc_normal_(self.b1, std=s1) + nn.init.trunc_normal_(self.w2, std=s2) + nn.init.trunc_normal_(self.b2, std=s2) + + self.act = nn.GELU() if activation is None else activation + self.drop = nn.Dropout(p=drop_rate) + + for param in self.parameters(): + param.__setattr__('moe_info', self.dist_info) + + def forward(self, inputs): # inputs [g, el, c, h] + + el = inputs.size(1) + h = inputs.size(-1) + + inputs = inputs.transpose(0, 1) + inshape = inputs.shape + inputs = inputs.reshape(el, -1, h) + + out_ff = torch.baddbmm(self.b1, inputs, self.w1) + out_act = self.act(out_ff) + with seed(ParallelMode.TENSOR): + out_inter = self.drop(out_act) + + out_model = torch.baddbmm(self.b2, out_inter, self.w2) + with seed(ParallelMode.TENSOR): + outputs = self.drop(out_model) # outputs [el, gc, h] + + outputs = outputs.reshape(inshape) + outputs = outputs.transpose(0, 1).contiguous() + return outputs + + +class TPExperts(MoeExperts): + """Use tensor parallelism to split each expert evenly, which can deploy experts in + case that the number of experts can't be divied by maximum expert parallel size or + maximum expert parallel size can't be divied by the number of experts. + """ + + def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + super().__init__("all_gather", MOE_CONTEXT.max_ep_size) + + assert d_ff % MOE_CONTEXT.max_ep_size == 0, \ + "d_ff should be divied by maximum expert parallel size" + + p_ff = d_ff // MOE_CONTEXT.max_ep_size + + self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device())) + self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device())) + + self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device())) + self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device())) + + s1 = math.sqrt(0.1 / d_model) + s2 = math.sqrt(0.1 / d_ff) + + with seed(ParallelMode.TENSOR): + nn.init.trunc_normal_(self.w1, std=s1) + nn.init.trunc_normal_(self.b1, std=s1) + nn.init.trunc_normal_(self.w2, std=s2) + + nn.init.trunc_normal_(self.b2, std=s2) + + self.act = nn.GELU() if activation is None else activation + self.drop = nn.Dropout(p=drop_rate) + + self.w1.__setattr__('moe_info', self.dist_info) + self.w2.__setattr__('moe_info', self.dist_info) + self.b1.__setattr__('moe_info', self.dist_info) + + def forward(self, inputs): # inputs [g, e, c, h] + + e = inputs.size(1) + h = inputs.size(-1) + + inputs = inputs.transpose(0, 1) + inshape = inputs.shape + inputs = inputs.reshape(e, -1, h) + + out_ff = torch.baddbmm(self.b1, inputs, self.w1) + out_act = self.act(out_ff) + with seed(ParallelMode.TENSOR): + out_inter = self.drop(out_act) + + out_model = torch.baddbmm(self.b2, out_inter, self.w2) + outputs = self.drop(out_model) # outputs [e, gc, h] + + outputs = outputs.reshape(inshape) + outputs = outputs.transpose(0, 1).contiguous() + return outputs # outputs [g, e, c, h] diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..235592b63618727dada1c32b3403e07fed0da26e --- /dev/null +++ b/colossalai/nn/layer/moe/layers.py @@ -0,0 +1,203 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.utils import get_current_device +from colossalai.nn.layer.moe._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, \ + ReduceScatter, MoeDispatch, MoeCombine +from colossalai.nn.layer.moe.experts import MoeExperts, Experts +from colossalai.nn.layer.moe.utils import UniformNoiseGenerator, NormalNoiseGenerator +from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router +from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator +from typing import Optional, Type, Tuple + + +@no_shard_zero_decrator(is_replicated=True) +class MoeLayer(nn.Module): + """A MoE layer, that puts its input tensor to its gate and uses the output logits + to router all tokens, is mainly used to exchange all tokens for every expert across + the moe tensor group by all to all comunication. Then it will get the output of all + experts and exchange the output. At last returns the output of the moe system. + + Args: + dim_model (int): Dimension of model. + num_experts (int): The number of experts. + router (MoeRouter): Instance of router used in routing. + experts (MoeExperts): Instance of experts generated by Expert. + """ + + def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts): + super().__init__() + self.d_model = dim_model + self.num_experts = num_experts + self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model)) + self.router: MoeRouter = router + self.experts: MoeExperts = experts + self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False + self.ep_group = experts.dist_info.ep_group + self.ep_size = experts.dist_info.ep_size + self.num_local_experts = experts.num_local_experts + + nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model)) + + def a2a_process(self, dispatch_data: torch.Tensor): + expert_input = AllToAll.apply(dispatch_data, self.ep_group) + input_shape = expert_input.shape + expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model) + expert_output = self.experts(expert_input) + expert_output = expert_output.reshape(input_shape) + expert_output = AllToAll.apply(expert_output, self.ep_group) + return expert_output + + def tp_process(self, dispatch_data: torch.Tensor): + expert_in = AllGather.apply(dispatch_data, self.ep_group) + expert_out = self.experts(expert_in) + expert_out = ReduceScatter.apply(expert_out, self.ep_group) + return expert_out + + def forward(self, inputs: torch.Tensor) -> Tuple: + # reshape the input tokens + tokens = inputs.reshape(-1, self.d_model) + + # the data type of the inputs in the gating should be fp32 + fp32_input = tokens.to(torch.float) + fp32_weight = self.gate_weight.to(torch.float) + gate_output = F.linear(fp32_input, fp32_weight) + + # the result from the router + route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) + + if self.use_kernel: + dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:]) + dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model) + else: + sec_mask_f = route_result_list[1].type_as(inputs) + dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) + + # dispatch_data [e, c, h] + if self.experts.comm_name == "all_to_all": + expert_output = self.a2a_process(dispatch_data) + elif self.experts.comm_name == "all_gather": + expert_output = self.tp_process(dispatch_data) + else: + raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " + "build function.") + # expert_output [e, c, h] + if self.use_kernel: + expert_output = expert_output.reshape(-1, self.d_model) + ans = MoeCombine.apply(expert_output, *route_result_list) + else: + combine_weights = route_result_list[0].type_as(inputs) + combine_weights = combine_weights.view(combine_weights.shape[0], -1) + expert_output = expert_output.view(-1, expert_output.shape[-1]) + ans = torch.matmul(combine_weights, expert_output) + + ans = ans.reshape(inputs.shape) + l_aux = self.router.pop_routing_loss() + return ans, l_aux + + +class MoeModule(nn.Module): + """A class for users to create MoE modules in their models. + + Args: + dim_model (int): Hidden dimension of training model + num_experts (int): The number experts + top_k (int, optional): The number of experts for dispatchment of each token + capacity_factor_train (float, optional): Capacity factor in routing during training + capacity_factor_eval (float, optional): Capacity factor in routing during evaluation + min_capacity (int, optional): The minimum number of the capacity of each expert + noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'. + 'Jitter' can be found in `Switch Transformer paper`_. + 'Gaussian' can be found in `ViT-MoE paper`_. + drop_tks (bool, optional): Whether drops tokens in evaluation + use_residual (bool, optional): Makes this MoE layer a Residual MoE. + More information can be found in `Microsoft paper`_. + residual_instance (nn.Module, optional): The instance of residual module in Resiual MoE + expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer + expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given + expert_args (optional): The args of expert when no instance is given + + .. _Switch Transformer paper: + https://arxiv.org/abs/2101.03961 + .. _ViT-MoE paper: + https://arxiv.org/abs/2106.05974 + .. _Microsoft paper: + https://arxiv.org/abs/2201.05596 + """ + + def __init__(self, + dim_model: int, + num_experts: int, + top_k: int = 1, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_policy: Optional[str] = None, + drop_tks: bool = True, + use_residual: bool = False, + residual_instance: Optional[nn.Module] = None, + expert_instance: Optional[MoeExperts] = None, + expert_cls: Optional[Type[nn.Module]] = None, + **expert_args): + super().__init__() + + noisy_func = None + if noisy_policy is not None: + if noisy_policy == 'Jitter': + noisy_func = UniformNoiseGenerator() + elif noisy_policy == 'Gaussian': + noisy_func = NormalNoiseGenerator(num_experts) + else: + raise NotImplementedError("Unsupported input noisy policy") + + if top_k == 1: + moe_router_cls = Top1Router + elif top_k == 2: + moe_router_cls = Top2Router + else: + raise NotImplementedError("top_k > 2 is not supported yet") + + self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + self.use_residual = use_residual + if use_residual: + if residual_instance is not None: + self.residual_module = residual_instance + else: + assert expert_cls is not None, \ + "Expert class can't be None when residual instance is not given" + self.residual_module = expert_cls(**expert_args) + + with no_shard_zero_context(): + self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device()) + + if expert_instance is not None: + self.experts = expert_instance + else: + assert expert_cls is not None, \ + "Expert class can't be None when experts instance is not given" + self.experts = Experts(expert_cls, num_experts, **expert_args) + + self.moe_layer = MoeLayer(dim_model=dim_model, + num_experts=num_experts, + router=self.moe_router, + experts=self.experts) + + def forward(self, inputs: torch.Tensor): + moe_output, l_aux = self.moe_layer(inputs) + + if self.use_residual: + residual_output = self.residual_module(inputs) + combine_coef = self.residual_combine(inputs) + combine_coef = F.softmax(combine_coef, dim=-1) + output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:] + else: + output = moe_output + + return output, l_aux diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/nn/layer/moe/routers.py new file mode 100644 index 0000000000000000000000000000000000000000..f11d6aa8e95ccca6bb35f9a661588e95afb6ec0e --- /dev/null +++ b/colossalai/nn/layer/moe/routers.py @@ -0,0 +1,226 @@ +import math +from abc import ABC + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from colossalai.utils import get_current_device +from colossalai.context import MOE_CONTEXT +from colossalai.nn.layer.moe._operation import moe_cumsum +from typing import Callable, Optional +from torch.distributed import ProcessGroup + + +class MoeRouter(nn.Module, ABC): + """Base class for all MoE routers. + Args: + k_value (int): The value of top_k. + capacity_factor_train (float): Capacity factor in routing of training. + capacity_factor_eval (float): Capacity factor in routing of evaluation. + min_capacity (int): The minimum number of the capacity of each expert. + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation + """ + + def __init__(self, + k_value: int, + capacity_factor_train: float, + capacity_factor_eval: float, + min_capacity: int, + noisy_func: Callable = None, + drop_tks: bool = True): + super().__init__() + self.k_value = k_value + self.capacity_factor_train = capacity_factor_train + self.capacity_factor_eval = capacity_factor_eval + self.min_capacity = min_capacity + self.noisy_func = noisy_func + self.drop_tks = drop_tks + self._routing_loss = None + + def get_capacity(self, logits_shape): + capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval + capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) + capacity += capacity % 2 + capacity = max(capacity, self.min_capacity) + assert capacity > 0 + return capacity + + def set_routing_loss(self, aux_loss: torch.Tensor) -> None: + assert self._routing_loss is None + self._routing_loss = aux_loss + + def pop_routing_loss(self) -> torch.Tensor: + assert self._routing_loss is not None + reservation = self._routing_loss + self._routing_loss = None + return reservation + + +class Top1Router(MoeRouter): + """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] + for routing usage. More deailted function can be found in the paper about Switch Transformer + of Google. + Args: + capacity_factor_train (float, optional): Capacity factor in routing of training. + capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. + min_capacity (int, optional): The minimum number of the capacity of each expert. + select_policy (str, optional): The policy about tokens selection. + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation + """ + + def __init__(self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + select_policy: str = "first", + noisy_func: Callable = None, + drop_tks: bool = True): + super().__init__(k_value=1, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + self.select_policy = select_policy + assert select_policy in {"first", "random"} + if select_policy == "random": + self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()), + high=torch.tensor(1.0, + device=get_current_device())).rsample + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): + + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + logits = F.softmax(inputs, dim=-1) + num_experts = logits.size(-1) + capacity = self.get_capacity(logits.shape) + + top1_idx = torch.argmax(inputs, dim=-1) + mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + + # caculate the auxiliary loss + me = torch.mean(logits, dim=0) + ce = torch.mean(mask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) + self.set_routing_loss(l_aux) + + if not self.training and not self.drop_tks: + max_num = torch.max(torch.sum(mask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + if self.select_policy == "random": + rand_mask = mask * self.uniform(mask.shape) + _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) + mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) + ranks = moe_cumsum(mask) + elif self.select_policy == "first": + ranks = moe_cumsum(mask) + mask = mask * torch.lt(ranks, capacity) + else: + raise NotImplementedError("Not support such select policy yet.") + + ranks = torch.sum(mask * ranks, dim=-1) + + if use_kernel: + mask = torch.sum(mask, dim=-1) + mask = torch.stack([mask], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) + return logits, mask, dest_idx, num_experts * capacity + else: + ranks = F.one_hot(ranks, num_classes=capacity) + weight = mask * logits.type_as(inputs) + combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) + sec_mask = combine_weights.bool() + return combine_weights, sec_mask + + +class Top2Router(MoeRouter): + """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] + for routing usage. More deailted function can be found in the paper about ViT-MoE. + Args: + capacity_factor_train (float, optional): Capacity factor in routing of training. + capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. + min_capacity (int, optional): The minimum number of the capacity of each expert + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation. + """ + + def __init__(self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Callable = None, + drop_tks: bool = True): + super().__init__(k_value=2, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): + # inputs: [s, h] + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + logits = F.softmax(inputs, dim=-1) # logits: [s, e] + num_experts = logits.size(-1) + capacity = self.get_capacity(logits.shape) + + top1_idx = torch.argmax(logits, dim=-1) + mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) + top2_idx = torch.argmax(logits_except1, dim=-1) + mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) + + cmask = (mask1 + mask2) # loss: [s, e] + + # caculate the auxiliary loss + me = torch.mean(logits, dim=0) + ce = torch.mean(cmask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 + self.set_routing_loss(l_aux) + + if not self.training and not self.drop_tks: + max_num = torch.max(torch.sum(cmask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + rank1 = moe_cumsum(mask1) # rank1: [s, e] + rank2 = moe_cumsum(mask2) + rank2 += torch.sum(mask1, dim=-2, keepdim=True) + + mask1 *= torch.lt(rank1, capacity) + mask2 *= torch.lt(rank2, capacity) + + rank1 = torch.sum(mask1 * rank1, dim=-1) + rank2 = torch.sum(mask2 * rank2, dim=-1) + + if use_kernel: + mask1 = torch.sum(mask1, dim=-1) + mask2 = torch.sum(mask2, dim=-1) + + mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) + + return logits, mask, dest_idx, num_experts * capacity + else: + weight1 = mask1 * logits.type_as(inputs) + weight2 = mask2 * logits.type_as(inputs) + rank1_sc = F.one_hot(rank1, num_classes=capacity) + rank2_sc = F.one_hot(rank2, num_classes=capacity) + + cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) + cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) + cb_weight = cb_weight1 + cb_weight2 + sec_mask = cb_weight.bool() + + return cb_weight, sec_mask diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8f0b9854b4cfa8ea81fe566bff99843fa60b3f50 --- /dev/null +++ b/colossalai/nn/layer/moe/utils.py @@ -0,0 +1,68 @@ +import torch +import torch.nn.functional as F +from colossalai.utils import get_current_device +from colossalai.context.moe_context import MOE_CONTEXT +from .experts import FFNExperts, TPExperts + + +class ForceFP32Parameter(torch.nn.Parameter): + + def half(self, memory_format=None): + return self.data.clone() + + +class NormalNoiseGenerator: + """Generates a random noisy mask for logtis tensor. + + All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where + `E = the number of experts`. + + Args: + num_experts (int): The number of experts. + """ + + def __init__(self, num_experts: int): + self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()), + scale=torch.tensor(1.0 / num_experts**2, + device=get_current_device())).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.normal(inputs.shape) + return inputs + noisy + + +class UniformNoiseGenerator: + """Generates a random noisy mask for logtis tensor. + copied from mesh tensorflow: + Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`. + Makes models more resilient to rounding errors introduced by bfloat16. + This seems particularly important for logits. + + Args: + eps (float, optional): Epsilon in generator, defaults 1e-2. + """ + + def __init__(self, eps: float = 1e-2): + self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()), + high=torch.tensor(1.0 + eps, + device=get_current_device())).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.uniform(inputs.shape) + return inputs * noisy + + +def autocast_softmax(logit: torch.Tensor, dim: int): + if logit.dtype != torch.float32: + logit = logit.float() + return F.softmax(logit, dim=dim) + + +def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + mep_size = MOE_CONTEXT.max_ep_size + if num_experts % mep_size == 0 or mep_size % num_experts == 0: + return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate) + elif d_ff % mep_size == 0: + return TPExperts(num_experts, d_model, d_ff, activation, drop_rate) + else: + raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") diff --git a/colossalai/nn/layer/parallel_1d/__init__.py b/colossalai/nn/layer/parallel_1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2353851df665246251bb7ef0d884dd5f961b7aac --- /dev/null +++ b/colossalai/nn/layer/parallel_1d/__init__.py @@ -0,0 +1,7 @@ +from .layers import (Classifier1D, Dropout1D, Embedding1D, LayerNorm1D, Linear1D, Linear1D_Col, Linear1D_Row, + PatchEmbedding1D, VocabParallelClassifier1D, VocabParallelEmbedding1D) + +__all__ = [ + 'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D', + 'VocabParallelEmbedding1D', 'LayerNorm1D', 'PatchEmbedding1D' +] diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..3943345582758edaa1298ac4d927eb132887e072 --- /dev/null +++ b/colossalai/nn/layer/parallel_1d/_operation.py @@ -0,0 +1,96 @@ +import torch +import torch.distributed as dist +from colossalai.core import global_context as gpc + +try: + import fused_mix_prec_layer_norm_cuda +except: + fused_mix_prec_layer_norm_cuda = None + + +class FusedLayerNormAffineFunction1D(torch.autograd.Function): + r"""Layernorm + + Args: + input: input matrix. + weight: weight matrix. + bias: bias matrix. + normalized_shape: input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability + """ + + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, + bias_, ctx.eps) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias \ + = fused_mix_prec_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, + input_, ctx.normalized_shape, + weight_, bias_, ctx.eps) + + return grad_input, grad_weight, grad_bias, None, None + + +class LinearWithAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication in backprop. + """ + + @staticmethod + def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.parallel_mode = parallel_mode + ctx.async_grad_allreduce = async_grad_allreduce + + output = torch.matmul(input_, weight.t()) + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + total_input = input + grad_input = grad_output.matmul(weight) + + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) + total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # all-reduce scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce): + return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce) diff --git a/colossalai/nn/layer/parallel_1d/_utils.py b/colossalai/nn/layer/parallel_1d/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1212d595635d7c305132edb8c74e01ab165a903b --- /dev/null +++ b/colossalai/nn/layer/parallel_1d/_utils.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +import torch.distributed as dist +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env + +from ..utils import divide + + +def set_parallel_input(input_parallel: bool): + env.parallel_input_1d = input_parallel + + +def get_parallel_input(): + return env.parallel_input_1d + + +def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank): + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + +def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): + per_partition_vocab_size = divide(global_vocab_size, world_size) + return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank) + + +def _reduce(input_, parallel_mode): + # skip if only one rank involved + if gpc.get_world_size(parallel_mode) == 1: + return input_ + group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) + dist.all_reduce(input_, group=group) + + return input_ + + +def _split(input_, parallel_mode, dim=-1): + # skip if only one rank involved + world_size = gpc.get_world_size(parallel_mode) + if world_size == 1: + return input_ + + # Split along last dimension. + dim_size = input_.size(dim) + assert dim_size % world_size == 0, \ + f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ + f'cannot split tensor evenly' + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + rank = gpc.get_local_rank(parallel_mode) + output = tensor_list[rank].contiguous() + + return output + + +def _gather(input_, parallel_mode, dim=-1): + # skip if only one rank involved + world_size = gpc.get_world_size(parallel_mode) + if world_size == 1: + return input_ + + # all gather + rank = gpc.get_local_rank(parallel_mode) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) + torch.distributed.all_gather(tensor_list, input_, group=group) + + # concat + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +class _ReduceGrad(torch.autograd.Function): + """ + Pass the input to the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + """ + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + def forward(ctx, input_, parallel_mode): + ctx.mode = parallel_mode + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output, ctx.mode), None + + +class _ReduceInput(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + """ + + @staticmethod + def symbolic(graph, input_): + return _reduce(input_) + + @staticmethod + def forward(ctx, input_, parallel_mode): + return _reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _split(input_) + + @staticmethod + def forward(ctx, input_, parallel_mode, dim): + ctx.mode = parallel_mode + ctx.dim = dim + return _split(input_, parallel_mode, dim) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output, ctx.mode, ctx.dim), None, None + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _gather(input_) + + @staticmethod + def forward(ctx, input_, parallel_mode, dim): + ctx.mode = parallel_mode + ctx.dim = dim + return _gather(input_, parallel_mode, dim) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.mode, ctx.dim), None, None + + +def reduce_grad(input_, parallel_mode): + return _ReduceGrad.apply(input_, parallel_mode) + + +def reduce_input(input_, parallel_mode): + return _ReduceInput.apply(input_, parallel_mode) + + +def split_forward_gather_backward(input_, parallel_mode, dim): + return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim) + + +def gather_forward_split_backward(input_, parallel_mode, dim): + return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim) diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..e96abd87ed1095c0c3337f651de01ad3840c4fb6 --- /dev/null +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -0,0 +1,1040 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from collections import OrderedDict +from typing import Callable, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn.parameter import Parameter + +from colossalai.communication import broadcast +from colossalai.context import ParallelMode, seed +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.kernel import LayerNorm +from colossalai.nn import init as init +from colossalai.registry import LAYERS +from colossalai.utils.checkpointing import ( + broadcast_state_dict, + gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict, +) +from colossalai.utils.cuda import get_current_device + +from ..base_layer import ParallelLayer +from ..colossalai_layer._utils import ColossalaiModule +from ..utils import divide, set_tensor_parallel_attribute_by_partition +from ..vanilla import VanillaLayerNorm, VanillaPatchEmbedding +from ._operation import linear_with_async_comm +from ._utils import ( + gather_forward_split_backward, + get_parallel_input, + reduce_grad, + reduce_input, + set_parallel_input, + split_forward_gather_backward, +) + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +@LAYERS.register_module +class Linear1D(ColossalaiModule): + r"""Linear layer for 1D parallelism. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + gather_output (bool, optional): Whether to call all-gather on output, defaults to False. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + parallel_input = get_parallel_input() + if not parallel_input and not gather_output: + layer = Linear1D_Col(in_features, + out_features, + bias=bias, + dtype=dtype, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer) + else: + layer = Linear1D_Row(in_features, + out_features, + bias=bias, + dtype=dtype, + parallel_input=parallel_input, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer) + super().__init__(layer) + + +@LAYERS.register_module +class LayerNorm1D(ColossalaiModule): + r""" + Layer Normalization for colossalai + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + _fast_ln_supported_sizes = [ + 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, + 24576, 25600, 30720, 32768, 40960, 49152, 65536 + ] + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): + if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: + norm = Fast_LN(normalized_shape, eps=eps).to(dtype) + else: + norm = None + try: + from apex.normalization import FusedLayerNorm + norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) + except ImportError: + norm = LayerNorm(normalized_shape, eps=eps).to(dtype) + super().__init__(norm) + + def _load_from_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) + super()._load_from_state_dict(local_state, prefix, *args) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + super()._save_to_state_dict(destination, prefix, keep_vars) + + +@LAYERS.register_module +class Classifier1D(ParallelLayer): + r"""RowLinear with given weight. Classifier of 1D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.parallel_input = get_parallel_input() + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = False + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ + 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) + input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) + + output_parallel = F.linear(input_, self.weight) + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + if self.bias is not None: + output = output + self.bias + return output + + +@LAYERS.register_module +class VocabParallelClassifier1D(ParallelLayer): + r"""ColLinear with given weight. Classifier of 1D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.gather_output = gather_output + self.parallel_input = get_parallel_input() + + # Divide the weight matrix along the last dimension. + self.num_classes_per_partition = divide(num_classes, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter(torch.empty(self.num_classes_per_partition, self.in_features, **factory_kwargs)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.empty(self.num_classes_per_partition, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = True + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + # Matrix multiply. + output_parallel = F.linear(input_parallel, self.weight, self.bias) + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + else: + output = output_parallel + return output + + +@LAYERS.register_module +class Linear1D_Col(ParallelLayer): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to Fals + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + + if bias: + self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + is_parallel_output = not self.gather_output + set_parallel_input(is_parallel_output) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + input_parallel = input_ + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + # output_parallel = F.linear(input_parallel, self.weight, bias) + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True) + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +@LAYERS.register_module +class Linear1D_Row(ParallelLayer): + r""" Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + parallel_input (bool, optional): If set to ``True``, it's assumed that the input is split, defaults to False. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to Fals + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) + input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=gpc.get_group(ParallelMode.PARALLEL_1D), + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias + + +@LAYERS.register_module +class Embedding1D(ParallelLayer): + r"""Embedding for 1D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed โ€œpadโ€, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size) + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.weight = Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1}, + partition_states={weight_key: True}) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + + output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + + return output + + +@LAYERS.register_module +class VocabParallelEmbedding1D(ParallelLayer): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed โ€œpadโ€, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None and \ + self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0}, + partition_states={weight_key: True}) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, + **self.embed_kwargs) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + return output + + +@LAYERS.register_module +class Dropout1D(ParallelLayer): + """Dropout layer of 1D parallelism. + + Args: + p (float, optional): probability of an element to be zeroed, defaults 0.5. + inplace (bool, optional): whether to do dropout in-place, default to be False. + """ + + def __init__(self, p: float = 0.5, inplace: bool = False): + super().__init__() + self.parallel_input = get_parallel_input() + self.p = p + self.inplace = inplace + + def forward(self, input_: Tensor) -> Tensor: + if self.parallel_input: + with seed(ParallelMode.TENSOR): + output = F.dropout(input_, self.p, self.training, self.inplace) + else: + output = F.dropout(input_, self.p, self.training, self.inplace) + return output + + +@LAYERS.register_module +class PatchEmbedding1D(ColossalaiModule): + """ + 2D Image to Patch Embedding + + :param img_size: image size + :type img_size: int + :param patch_size: patch size + :type patch_size: int + :param in_chans: number of channels of input image + :type in_chans: int + :param embed_size: size of embedding + :type embed_size: int + :param dtype: The dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param flatten: whether to flatten output tensor, defaults to True + :type flatten: bool, optional + :param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer + :type weight_initializer: typing.Callable, optional + :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer + :type bias_initializer: typing.Callable, optional + :param position_embed_initializer: The intializer of position embedding, defaults to zero + :type position_embed_initializer: typing.Callable, optional + """ + + def __init__(self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + dtype: torch.dtype = None, + flatten: bool = True, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_()): + embed = VanillaPatchEmbedding(img_size, + patch_size, + in_chans, + embed_size, + dtype=dtype, + flatten=flatten, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + position_embed_initializer=position_embed_initializer) + super().__init__(embed) + + def _load_from_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed'] + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + for key in param_keys: + param = state_dict.pop(key, None) + if param is not None: + local_state[key] = param + + local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) + super()._load_from_state_dict(local_state, prefix, *args) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + super()._save_to_state_dict(destination, prefix, keep_vars) diff --git a/colossalai/nn/layer/parallel_2d/__init__.py b/colossalai/nn/layer/parallel_2d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5562d1a700361c23bd8238848b5141c05c9b25aa --- /dev/null +++ b/colossalai/nn/layer/parallel_2d/__init__.py @@ -0,0 +1,8 @@ +from ._operation import reduce_by_batch_2d, split_batch_2d +from .layers import (Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, VocabParallelClassifier2D, + VocabParallelEmbedding2D) + +__all__ = [ + 'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', + 'Embedding2D', 'VocabParallelEmbedding2D', 'VocabParallelClassifier2D' +] diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/nn/layer/parallel_2d/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..306577dbd9333987bb181d70ef21fcff1d548b7c --- /dev/null +++ b/colossalai/nn/layer/parallel_2d/_operation.py @@ -0,0 +1,849 @@ +from typing import Any, Optional, Tuple + +import torch +import torch.distributed as dist +from colossalai.communication.collective import (all_gather, all_reduce, reduce, reduce_scatter) +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.global_variables import tensor_parallel_env as env + + +def matmul_2d( + a, + b, + summa_dim, + out_shape, + row_rank=None, + col_rank=None, + row_parallel_mode=ParallelMode.PARALLEL_2D_ROW, + col_parallel_mode=ParallelMode.PARALLEL_2D_COL, +): + r"""Matrix multiplication for 2D parallelism. + + Args: + a (:class:`torch.tensor`): matrix :math:`A`. + b (:class:`torch.tensor`): matrix :math:`B`. + summa_dim (int): dimension of SUMMA fo 2D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.context.ParallelMode`, optional): + row parallel mode, defaults to ParallelMode.PARALLEL_2D_ROW. + col_parallel_mode (:class:`colossalai.context.ParallelMode`, optional): + column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL. + + Returns: + :class:`torch.tensor`: :math:`C = AB`. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + if row_rank is None: + row_rank = gpc.get_local_rank(col_parallel_mode) + if col_rank is None: + col_rank = gpc.get_local_rank(row_parallel_mode) + + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + tensor_parallel_size = summa_dim**2 + return Matmul_AB_2D(a, b, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, col_parallel_mode, + data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) + + +class _Classifier2D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + bias: Optional[Tensor], + summa_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + A = A.clone().detach() + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + B_temp = all_gather(B, -1, col_parallel_mode) + if ctx: + ctx.save_for_backward(A, B_temp) + + C = torch.matmul(A, B_temp.transpose(0, 1)) + + C = all_reduce(C, row_parallel_mode) + + ctx.use_bias = bias is not None + if bias is not None: + C = C + bias + + out = C.reshape(out_shape) + + if ctx: + ctx.summa_dim = summa_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + + with torch.no_grad(): + A_grad = torch.matmul(output_grad, B) + A_grad = A_grad.reshape(ctx.A_shape) + B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A) + B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode) + B_grad = B_grad.reshape(ctx.B_shape) + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1))) + bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode) + else: + bias_grad = None + + return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None + + +def classifier_2d(A: Tensor, B: Tensor, bias: Optional[Tensor], summa_dim: int, out_shape: Tuple[int, ...], + row_rank: int, col_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int) -> Tensor: + r"""2D parallel classifier. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + bias (:class:`torch.tensor`, optional): matrix of bias. + summa_dim (int): dimension of SUMMA fo 2D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Classifier2D.apply(A, B, bias, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, + col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, + tensor_parallel_size) + + +class Matmul_AB_2D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = AB`. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + summa_dim (int): dimension of SUMMA fo 2D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + summa_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + # A: [b / q, s, h / q] -> [(b * s) / q, h / q] + # B: [h / q, s / q] + # C: [b / q, s, s / q] -> [(b * s) / q, s / q] + + assert A.shape[-1] == B.shape[-2], \ + 'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[0], B.shape[-1]) + C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + A_list = [torch.empty_like(A) for _ in range(2)] + B_list = [torch.empty_like(B) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + + opa = [None] * 2 + opb = [None] * 2 + + A_list[0].copy_(A) + B_list[0].copy_(B) + opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) + opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) + cur = 0 + + for i in range(summa_dim): + if i != summa_dim - 1: + A_list[1 - cur].copy_(A) + opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) + B_list[1 - cur].copy_(B) + opb[1 - cur] = dist.broadcast(B_list[1 - cur], src=src_b + summa_dim, group=col_group, async_op=True) + + if opa[cur] is not None: + opa[cur].wait() + if opb[cur] is not None: + opb[cur].wait() + + torch.addmm(C, A_list[cur], B_list[cur], out=C) + cur = 1 - cur + src_a += 1 + src_b += summa_dim + + out = C.reshape(out_shape) + + if ctx: + ctx.summa_dim = summa_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + with torch.no_grad(): + A_grad = Matmul_ABT_2D.apply(output_grad, B, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, + ctx.tensor_parallel_size) + B_grad = Matmul_ATB_2D.apply(A, output_grad, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, + ctx.tensor_parallel_size) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None + + +class Matmul_ABT_2D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = AB^T` + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + summa_dim (int): dimension of SUMMA fo 2D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + summa_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + + assert A.shape[-1] == B.shape[-1], \ + 'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[0], B.shape[0]) + C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + B_list = [torch.empty_like(B) for _ in range(2)] + C_list = [torch.empty_like(C) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + + opb = [None] * 2 + opr = [None] * 2 + + B_list[0].copy_(B) + opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) + cur = 0 + + for i in range(summa_dim): + if i != summa_dim - 1: + B_list[1 - cur].copy_(B) + opb[1 - cur] = dist.broadcast(B_list[1 - cur], src=src_b + summa_dim, group=col_group, async_op=True) + + if opr[cur] is not None: + opr[cur].wait() + if i - 2 == col_rank: + C.copy_(C_list[cur]) + + if opb[cur] is not None: + opb[cur].wait() + + torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur]) + opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True) + cur = 1 - cur + src_b += summa_dim + src_c += 1 + + for op in opr: + op.wait() + + if summa_dim - 2 == col_rank: + C.copy_(C_list[cur]) + if summa_dim - 1 == col_rank: + C.copy_(C_list[1 - cur]) + out = C.reshape(out_shape) + + if ctx: + ctx.summa_dim = summa_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + + with torch.no_grad(): + A_grad = Matmul_AB_2D.apply(output_grad, B, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, + ctx.tensor_parallel_size) + B_grad = Matmul_ATB_2D.apply(output_grad, A, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, + ctx.tensor_parallel_size) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None + + +class Matmul_ATB_2D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = A^TB`. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + summa_dim (int): dimension of SUMMA fo 2D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + summa_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + + assert A.shape[-2] == B.shape[-2], \ + 'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[-1], B.shape[-1]) + C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + A_list = [torch.empty_like(A) for _ in range(2)] + C_list = [torch.empty_like(C) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + + opa = [None] * 2 + opr = [None] * 2 + + A_list[0].copy_(A) + opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) + cur = 0 + + for i in range(summa_dim): + if i != summa_dim - 1: + A_list[1 - cur].copy_(A) + opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) + + if opr[cur] is not None: + opr[cur].wait() + if i - 2 == row_rank: + C.copy_(C_list[cur]) + + if opa[cur] is not None: + opa[cur].wait() + + torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur]) + opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True) + cur = 1 - cur + src_a += 1 + src_c += summa_dim + + for op in opr: + op.wait() + + if summa_dim - 2 == row_rank: + C.copy_(C_list[cur]) + if summa_dim - 1 == row_rank: + C.copy_(C_list[1 - cur]) + out = C.reshape(out_shape) + + if ctx: + ctx.summa_dim = summa_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + + with torch.no_grad(): + A_grad = Matmul_ABT_2D.apply(B, output_grad, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, + ctx.tensor_parallel_size) + B_grad = Matmul_AB_2D.apply(A, output_grad, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, + ctx.tensor_parallel_size) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None + + +class _Add_Bias_2D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + input_: Tensor, + bias: Tensor, + output_size_per_partition: int, + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + skip_bias_add: bool, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + bias_temp = all_gather(bias, -1, col_parallel_mode) + + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.bias = skip_bias_add + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + if skip_bias_add: + return bias_temp + else: + output = input_ + bias_temp + return output + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + col_parallel_mode = ctx.col_parallel_mode + + if ctx.bias: + grad = reduce_scatter(output_grad, -1, col_parallel_mode) + return None, grad, None, None, None, None, None, None, None, None, None, None + else: + reduce_dim = tuple(range(output_grad.ndim - 1)) + reduce = torch.sum(output_grad, dim=reduce_dim) + grad = reduce_scatter(reduce, -1, col_parallel_mode) + return output_grad, grad, None, None, None, None, None, None, None, None, None, None + + +def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, row_rank: int, col_rank: int, + row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, skip_bias_add: bool, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int) -> Tensor: + r"""Matrix add bias: :math:`C = A + b`. + + Args: + input_ (:class:`torch.tensor`): matrix :math:`A`. + bias (:class:`torch.tensor`): matrix :math:`B`. + output_size_per_partition (int): size of output per partition. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + skip_bias_add (bool): + If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Add_Bias_2D.apply(input_, bias, output_size_per_partition, row_rank, col_rank, row_parallel_mode, + col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank, + pipeline_parallel_size, tensor_parallel_size) + + +class _Layernorm_2D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode) -> Tensor: + input_ = input_ - E_x + # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) + ctx.normalized_shape = hidden_size + output = input_ * Var_x + ctx.save_for_backward(output, Var_x) + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + return output + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + row_parallel_mode = ctx.row_parallel_mode + col_parallel_mode = ctx.col_parallel_mode + x, Var_x = ctx.saved_tensors + # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x + output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True) + torch.distributed.all_reduce(output_grad_sum, group=gpc.get_group(row_parallel_mode)) + output_grad_sum /= ctx.normalized_shape + + output_grad_mul_x_sum = torch.sum(output_grad * x, dim=-1, keepdim=True) + torch.distributed.all_reduce(output_grad_mul_x_sum, group=gpc.get_group(row_parallel_mode)) + output_grad_mul_x_sum /= ctx.normalized_shape + + input_grad = output_grad.clone() + input_grad -= x * output_grad_mul_x_sum + input_grad -= output_grad_sum + input_grad *= Var_x + + return input_grad, None, None, None, None, None + + +def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode) -> Tensor: + r"""Layernorm. + + Args: + input_ (:class:`torch.tensor`): input matrix. + E_x (:class:`torch.tensor`): mean. + Var_x (:class:`torch.tensor`): variance. + hidden_size (int): hidden size. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Layernorm_2D.apply(input_, E_x, Var_x, hidden_size, row_parallel_mode, col_parallel_mode) + + +class _AllGatherTensor2D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, inputs: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + ctx.dim = dim + ctx.parallel_mode = parallel_mode + + outputs = all_gather(inputs, dim, parallel_mode) + return outputs + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + grad = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode) + return grad.contiguous(), None, None + + +def all_gather_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""All gather the tensor of 2D parallelism. + + Args: + tensor (:class:`torch.tensor`): Input tensor. + dim (int): Dimension to gather. + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _AllGatherTensor2D.apply(tensor, dim, parallel_mode) + + +def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor: + """Splits 2D tensor in specified dimension across cols. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + dim (int): Specified dimension in which to split. + + Returns: + :class:`torch.tensor`: The tensor has been split. + """ + dim_size = input_.size(dim) + world_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) + + if world_size <= 1: + return input_ + + assert dim_size % world_size == 0, \ + f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).' + + return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL), + dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous() + + +class _ReduceTensor2D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, parallel_mode): + return all_reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return output_grad, None + + +def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: + r"""All-reduce the input. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _ReduceTensor2D.apply(input_, parallel_mode) + + +class _ReduceScatterTensor2D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, dim, parallel_mode): + ctx.dim = dim + ctx.parallel_mode = parallel_mode + return reduce_scatter(input_, dim, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return all_gather(output_grad, ctx.dim, ctx.parallel_mode), None, None + + +def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""Reduce-scatter the input. + + Args: + tensor (:class:`torch.tensor`): Input tensor. + dim (int): Dimension to reduce. + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + dim_size = tensor.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert dim_size % world_size == 0, \ + f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).' + + return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode) + + +class _ReduceByBatch2D(torch.autograd.Function): + + @staticmethod + def symbolic(graph, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) + return output / reduce_size + return output + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) + ctx.reduce_mean = reduce_mean + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) + ctx.reduce_size = reduce_size + return output.clone() / reduce_size + return output.clone() + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + if ctx.reduce_mean: + return output_grad / ctx.reduce_size, None + else: + return output_grad, None + + +def reduce_by_batch_2d(input_, reduce_mean: bool = False) -> Tensor: + r"""All-reduce the input from the model parallel region. + + Args: + input_ (:class:`torch.tensor`): input matrix. + reduce_mean (bool, optional): + If set to ``True``, it will divide the output by column parallel size, default to False. + """ + return _ReduceByBatch2D.apply(input_, reduce_mean) diff --git a/colossalai/nn/layer/parallel_2d/_utils.py b/colossalai/nn/layer/parallel_2d/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..012fec41c80231165ceb92e57e2f449e61fdb8b2 --- /dev/null +++ b/colossalai/nn/layer/parallel_2d/_utils.py @@ -0,0 +1,20 @@ +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env + + +def get_summa_dim_from_env() -> int: + try: + summa_dim = env.summa_dim + assert summa_dim > 0, 'SUMMA_DIM must be larger than zero' + return summa_dim + + except KeyError as e: + raise EnvironmentError('SUMMA_DIM is not found in the current environment, ' + 'please make sure that you have used the correct process group initializer') + + +def assert_summa_initialization(): + assert gpc.is_initialized(ParallelMode.PARALLEL_2D_COL) and \ + gpc.is_initialized(ParallelMode.PARALLEL_2D_ROW), \ + 'Both TWO_DIMENSION_COL and TWO_DIMENSION_ROW must be initialized by the process group initializer' diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/nn/layer/parallel_2d/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a4d2bbbc32f8f13815ea36a7c47268217446d5 --- /dev/null +++ b/colossalai/nn/layer/parallel_2d/layers.py @@ -0,0 +1,1201 @@ +import math +from collections import OrderedDict +from typing import Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from colossalai.communication import broadcast +from colossalai.context import ParallelMode, seed +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn import init as init +from colossalai.registry import LAYERS +from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict +from colossalai.utils.cuda import get_current_device +from torch import Tensor +from torch.nn import Parameter + +from ..base_layer import ParallelLayer +from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple +from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d, + reduce_scatter_tensor_2d, split_batch_2d) +from ._utils import assert_summa_initialization, get_summa_dim_from_env + + +@LAYERS.register_module +class Linear2D(ParallelLayer): + r"""Linear layer for 2D parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.skip_bias_add = skip_bias_add + + # parallel settings + assert_summa_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + self.summa_dim = get_summa_dim_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(self.in_features, self.summa_dim) + self.hidden_size_per_partition = divide(self.out_features, self.summa_dim) + + # create weight, shape: [k/q, h/q] + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter( + torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)) + + # create bias, shape: [h/q] + if bias: + self.bias = Parameter(torch.empty(divide(self.out_features, self.summa_dim**2), **factory_kwargs)) + else: + self.register_parameter('bias', None) + + # initialize parameters + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight.transpose(0, 1) + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + local_state[weight_key] = local_state[weight_key].transpose(0, 1) + destination.update(local_state) + + def forward(self, x: Tensor) -> Tensor: + # input: [m/q, n/q, k/q] + # output: [m/q, n/q, h/q] + out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) + + output = Matmul_AB_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, + self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) + + if self.bias is not None: + if self.skip_bias_add: + bias = add_bias_2d(None, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + return output, bias + else: + output = add_bias_2d(output, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + return output + else: + return output + + +@LAYERS.register_module +class LayerNorm2D(ParallelLayer): + r"""Layer Normalization for 2D parallelism. + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None): + super().__init__() + + # layer norm config + self.normalized_shape = normalized_shape + self.variance_epsilon = eps + + # parallel setting + assert_summa_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + self.summa_dim = get_summa_dim_from_env() + + # partitioning dimension + self.partitioned_partition = divide(normalized_shape, self.summa_dim**2) + + # create parameters + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + + self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) + if bias: + self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs)) + else: + self.bias = None + + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, x: Tensor) -> Tensor: + with torch.no_grad(): + E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] + torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + E_x /= self.normalized_shape + + # Var_x in the block below is the sum of input^2 + Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] + torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + Var_x /= self.normalized_shape + + Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] + # this time 1/sqrt(Var_x + epsilon) + Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) + + output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL) + scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank, + self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) + if self.bias is not None: + bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + output = torch.addcmul(bias, scale, output) + else: + output = torch.mul(scale, output) + return output + + +@LAYERS.register_module +class PatchEmbedding2D(ParallelLayer): + r"""2D Image to Patch Embedding. + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_()): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + assert_summa_initialization() + self.summa_dim = get_summa_dim_from_env() + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.embed_size = embed_size + self.embed_size_per_partition = embed_size // (self.summa_dim**2) + + with seed(ParallelMode.TENSOR): + self.weight = Parameter( + torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size), + device=get_current_device(), + dtype=dtype)) + self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) + + self.cls_token = Parameter( + torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)) + self.pos_embed = Parameter( + torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition), + device=get_current_device(), + dtype=dtype)) + + self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) + self._set_tensor_parallel_attribute() + + def _set_tensor_parallel_attribute(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) + set_tensor_parallel_attribute_by_partition(self.cls_token, self.summa_dim**2) + set_tensor_parallel_attribute_by_partition(self.pos_embed, self.summa_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): + with seed(ParallelMode.TENSOR): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + fan_out = self.embed_size + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + bias_initializer(self.bias, fan_in=fan_in) + position_embed_initializer(self.pos_embed) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + cls_token_key = prefix + 'cls_token' + pos_embed_key = prefix + 'pos_embed' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + # cls token + cls_token = state_dict.pop(cls_token_key, None) + if cls_token is not None: + local_state[cls_token_key] = cls_token + # pos embed + pos_embed = state_dict.pop(pos_embed_key, None) + if pos_embed is not None: + local_state[pos_embed_key] = pos_embed + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + cls_token_key = prefix + 'cls_token' + pos_embed_key = prefix + 'pos_embed' + local_state = OrderedDict({ + weight_key: self.weight, + bias_key: self.bias, + cls_token_key: self.cls_token, + pos_embed_key: self.pos_embed + }) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_2d(input_) + + B, C, H, W = input_.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + + weight = all_gather_tensor_2d(self.weight, 0, ParallelMode.PARALLEL_2D_COL) + bias = all_gather_tensor_2d(self.bias, 0, ParallelMode.PARALLEL_2D_COL) + + output = F.conv2d(input_, weight, bias, stride=self.patch_size) + if self.flatten: + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + + cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL) + pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL) + cls_token = cls_token.expand(output.shape[0], -1, -1) + output = torch.cat((cls_token, output), dim=1) + output = output + pos_embed + + return output + + +@LAYERS.register_module +class Embedding2D(ParallelLayer): + r"""Embedding for 2D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed โ€œpadโ€, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + + assert_summa_initialization() + self.summa_dim = get_summa_dim_from_env() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + embed_dim_per_partition = divide(embedding_dim, self.summa_dim**2) + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.weight = Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_2d(input_) + + weight = all_gather_tensor_2d(self.weight, -1, ParallelMode.PARALLEL_2D_COL) + output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + return output + + +@LAYERS.register_module +class VocabParallelEmbedding2D(ParallelLayer): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed โ€œpadโ€, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + assert_summa_initialization() + self.summa_dim = get_summa_dim_from_env() + self.num_embeddings_per_partition = divide(self.num_embeddings, self.summa_dim) + self.embed_dim_per_partition = divide(self.embed_dim, self.summa_dim) + tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None and \ + self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, + **self.embed_kwargs) + + output_parallel[input_mask, :] = 0. + output = reduce_scatter_tensor_2d(output_parallel, 0, ParallelMode.PARALLEL_2D_COL) + return output + + +@LAYERS.register_module +class Classifier2D(ParallelLayer): + r"""Classifier for 2D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + assert_summa_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + self.summa_dim = get_summa_dim_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(self.in_features, self.summa_dim**2) + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.num_classes + col_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)[0] + row_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_ROW)[0] + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL) + broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + out_shape = input_.shape[:-1] + (self.num_classes,) + + return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, + self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) + + +@LAYERS.register_module +class VocabParallelClassifier2D(ParallelLayer): + r"""Vocab parallel classifier layer for 2D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + self.in_features = in_features + self.num_classes = num_classes + + # parallel setting + assert_summa_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + self.summa_dim = get_summa_dim_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(in_features, self.summa_dim) + self.output_size_per_partition = divide(num_classes, self.summa_dim) + + # create weight, shape: [k/q, h/q] + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.output_size_per_partition, self.input_size_per_partition, **factory_kwargs)) + self.has_weight = True + # create bias, shape: [h/q] + if bias: + self.bias = Parameter(torch.empty(divide(self.num_classes, self.summa_dim**2), **factory_kwargs)) + else: + self.bias = None + + # initialize parameters + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + local_state[weight_key] = local_state[weight_key].transpose(0, 1) + destination.update(local_state) + + def forward(self, x: Tensor) -> Tensor: + # input: [m/q, n/q, k/q] + # output: [m/q, n/q, h/q] + out_shape = x.shape[:-1] + (self.output_size_per_partition,) + + output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + + if self.bias is not None: + output = add_bias_2d(output, self.bias, self.output_size_per_partition, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + return output diff --git a/colossalai/nn/layer/parallel_2p5d/__init__.py b/colossalai/nn/layer/parallel_2p5d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bec3b1c4b0b87e8db497b627207ca6f30b1fff49 --- /dev/null +++ b/colossalai/nn/layer/parallel_2p5d/__init__.py @@ -0,0 +1,8 @@ +from ._operation import reduce_by_batch_2p5d, split_batch_2p5d +from .layers import (Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D, + VocabParallelClassifier2p5D, VocabParallelEmbedding2p5D) + +__all__ = [ + 'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D', + 'Embedding2p5D', 'VocabParallelClassifier2p5D', 'VocabParallelEmbedding2p5D' +] diff --git a/colossalai/nn/layer/parallel_2p5d/_operation.py b/colossalai/nn/layer/parallel_2p5d/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..5a0f537cd6d9ca0d6104390a8cdb4634f25d204c --- /dev/null +++ b/colossalai/nn/layer/parallel_2p5d/_operation.py @@ -0,0 +1,880 @@ +from typing import Any, Tuple + +import torch +import torch.distributed as dist +from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter) +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + + +def get_parallel_group(parallel_mode: ParallelMode): + return gpc.get_group(parallel_mode) + + +def get_global_rank(): + return gpc.get_global_rank() + + +def get_parallel_rank(parallel_mode: ParallelMode): + return gpc.get_local_rank(parallel_mode) + + +class _Classifier2p5D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + bias, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + A = A.clone().detach() + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + B_temp = all_gather(B, -1, col_parallel_mode) + if ctx: + ctx.save_for_backward(A, B_temp) + + C = torch.matmul(A, B_temp.transpose(0, 1)) + + C = all_reduce(C, row_parallel_mode) + + ctx.use_bias = bias is not None + if bias is not None: + C = C + bias + + out = C.reshape(out_shape) + + if ctx: + ctx.tesseract_dim = tesseract_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + + with torch.no_grad(): + A_grad = torch.matmul(output_grad, B) + A_grad = A_grad.reshape(ctx.A_shape) + B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A) + B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode) + B_grad = B_grad.reshape(ctx.B_shape) + + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1))) + bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode) + else: + bias_grad = None + + return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None + + +def classifier_2p5d(A: Tensor, B: Tensor, bias, tesseract_dim: int, out_shape: Tuple[int, + ...], row_rank: int, col_rank: int, + row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, data_parallel_rank: int, + pipeline_parallel_rank: int, pipeline_parallel_size: int, tensor_parallel_size: int) -> Tensor: + r"""Classifier. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + bias (:class:`torch.tensor`): matrix of bias. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int): the rank of row. + col_rank (int): the rank of column. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Classifier2p5D.apply(A, B, bias, tesseract_dim, out_shape, row_rank, col_rank, row_parallel_mode, + col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, + tensor_parallel_size) + + +class Matmul_AB_2p5D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = AB`. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int): the rank of row. + col_rank (int): the rank of column. + dep_rank (int): the rank of depth. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, + col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int) -> Tensor: + # A: [b / dq, s, h / q] -> [(b * s) / dq, h / q] + # B: [h / dq, s / q] + # C: [b / dq, s, s / q] -> [(b * s) / dq, s / q] + + assert A.shape[-1] == B.shape[-2], \ + 'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[0], B.shape[-1]) + C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + A_list = [torch.empty_like(A) for _ in range(2)] + B_list = [torch.empty_like(B) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_a = \ + tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + src_b = \ + col_rank + tesseract_dim ** 2 * dep_rank + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + + opa = [None] * 2 + opb = [None] * 2 + + A_list[0].copy_(A) + B_list[0].copy_(B) + opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) + opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) + cur = 0 + + for i in range(tesseract_dim): + if i != tesseract_dim - 1: + A_list[1 - cur].copy_(A) + opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) + B_list[1 - cur].copy_(B) + opb[1 - cur] = dist.broadcast(B_list[1 - cur], + src=src_b + tesseract_dim, + group=col_group, + async_op=True) + + if opa[cur] is not None: + opa[cur].wait() + if opb[cur] is not None: + opb[cur].wait() + + torch.addmm(C, A_list[cur], B_list[cur], out=C) + cur = 1 - cur + src_a += 1 + src_b += tesseract_dim + out = C.reshape(out_shape) + + if ctx: + ctx.tesseract_dim = tesseract_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.dep_rank = dep_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + with torch.no_grad(): + A_grad = Matmul_ABT_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, + ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, + ctx.data_parallel_rank, ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + B_grad = Matmul_ATB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, + ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, + ctx.data_parallel_rank, ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None + + +class Matmul_ABT_2p5D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = AB^T`. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int): the rank of row. + col_rank (int): the rank of column. + dep_rank (int): the rank of depth. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, + col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int) -> Tensor: + + assert A.shape[-1] == B.shape[-1], \ + 'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[0], B.shape[0]) + C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + B_list = [torch.empty_like(B) for _ in range(2)] + C_list = [torch.empty_like(C) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_b = \ + col_rank + tesseract_dim ** 2 * dep_rank + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + src_c = \ + tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + + opb = [None] * 2 + opr = [None] * 2 + + B_list[0].copy_(B) + opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) + cur = 0 + + for i in range(tesseract_dim): + if i != tesseract_dim - 1: + B_list[1 - cur].copy_(B) + opb[1 - cur] = dist.broadcast(B_list[1 - cur], + src=src_b + tesseract_dim, + group=col_group, + async_op=True) + + if opr[cur] is not None: + opr[cur].wait() + if i - 2 == col_rank: + C.copy_(C_list[cur]) + + if opb[cur] is not None: + opb[cur].wait() + + torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur]) + opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True) + cur = 1 - cur + src_b += tesseract_dim + src_c += 1 + + for op in opr: + op.wait() + + if tesseract_dim - 2 == col_rank: + C.copy_(C_list[cur]) + if tesseract_dim - 1 == col_rank: + C.copy_(C_list[1 - cur]) + out = C.reshape(out_shape) + + if ctx: + ctx.tesseract_dim = tesseract_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.dep_rank = dep_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + with torch.no_grad(): + A_grad = Matmul_AB_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, + ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, + ctx.data_parallel_rank, ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + B_grad = Matmul_ATB_2p5D.apply(output_grad, A, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, + ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, + ctx.data_parallel_rank, ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None + + +class Matmul_ATB_2p5D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = A^TB` + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int): the rank of row. + col_rank (int): the rank of column. + dep_rank (int): the rank of depth. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, + col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int): + + assert A.shape[-2] == B.shape[-2], \ + 'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[-1], B.shape[-1]) + C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + A_list = [torch.empty_like(A) for _ in range(2)] + C_list = [torch.empty_like(C) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_a = \ + tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + src_c = \ + col_rank + tesseract_dim ** 2 * dep_rank + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + + opa = [None] * 2 + opr = [None] * 2 + + A_list[0].copy_(A) + opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) + cur = 0 + + for i in range(tesseract_dim): + if i != tesseract_dim - 1: + A_list[1 - cur].copy_(A) + opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) + + if opr[cur] is not None: + opr[cur].wait() + if i - 2 == row_rank: + C.copy_(C_list[cur]) + + if opa[cur] is not None: + opa[cur].wait() + + torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur]) + opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True) + cur = 1 - cur + src_a += 1 + src_c += tesseract_dim + + for op in opr: + op.wait() + + if tesseract_dim - 2 == row_rank: + C.copy_(C_list[cur]) + if tesseract_dim - 1 == row_rank: + C.copy_(C_list[1 - cur]) + out = C.reshape(out_shape) + + if ctx: + ctx.tesseract_dim = tesseract_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.dep_rank = dep_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + with torch.no_grad(): + A_grad = Matmul_ABT_2p5D.apply(B, output_grad, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, + ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, + ctx.data_parallel_rank, ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + B_grad = Matmul_AB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, + ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, + ctx.data_parallel_rank, ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None + + +class _Add_Bias_2p5D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, + row_rank: int, col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int) -> Tensor: + if row_rank == 0: + bias_temp = bias.clone() + else: + bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device()) + src_rank = \ + col_rank + dep_rank * tesseract_dim ** 2 + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode)) + + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.dep_rank = dep_rank + ctx.tesseract_dim = tesseract_dim + ctx.col_parallel_mode = col_parallel_mode + ctx.bias = skip_bias_add + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + if skip_bias_add: + return bias_temp + else: + output = input + bias_temp + return output + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + row_rank = ctx.row_rank + col_rank = ctx.col_rank + dep_rank = ctx.dep_rank + tesseract_dim = ctx.tesseract_dim + col_parallel_mode = ctx.col_parallel_mode + data_parallel_rank = ctx.data_parallel_rank + pipeline_parallel_rank = ctx.pipeline_parallel_rank + pipeline_parallel_size = ctx.pipeline_parallel_size + tensor_parallel_size = ctx.tensor_parallel_size + + if ctx.bias: + dst_rank = \ + col_rank + dep_rank * (tesseract_dim ** 2) + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + dist.reduce(output_grad, dst=dst_rank, group=get_parallel_group(col_parallel_mode)) + if row_rank == 0: + return \ + None, output_grad, None, None, None, None, None, None, \ + None, None, None, None, None, None, None, None + else: + grad_tmp = torch.zeros_like(output_grad) + return \ + None, grad_tmp, None, None, None, None, None, None, \ + None, None, None, None, None, None, None, None + else: + reduce_dim = tuple(range(output_grad.ndim - 1)) + reduce = torch.sum(output_grad, dim=reduce_dim) + dst_rank = \ + col_rank + dep_rank * (tesseract_dim ** 2) + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + dist.reduce(reduce, dst=dst_rank, group=get_parallel_group(col_parallel_mode)) + if row_rank == 0: + return \ + output_grad, reduce, None, None, None, None, None, None, None, \ + None, None, None, None, None, None, None, None + else: + reduce_tmp = torch.zeros_like(reduce) + return \ + output_grad, reduce_tmp, None, None, None, None, None, None, \ + None, None, None, None, None, None, None, None, None + + +def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, row_rank: int, + col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int) -> Tensor: + r"""Matrix add bias: :math:`C = A + b`. + + Args: + input (:class:`torch.tensor`): matrix :math:`A`. + bias (:class:`torch.tensor`): matrix :math:`B`. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. + output_size_per_partition (int): output size in each partition. + row_rank (int): the rank of row. + col_rank (int): the rank of column. + dep_rank (int): the rank of depth. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Add_Bias_2p5D.apply(input, bias, output_size_per_partition, tesseract_dim, row_rank, col_rank, dep_rank, + col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank, + pipeline_parallel_size, tensor_parallel_size) + + +class _Layernorm2p5D(torch.autograd.Function): + r"""Layernorm. + + Args: + input (:class:`torch.tensor`): input matrix. + E_x (:class:`torch.tensor`): mean. + Var_x (:class:`torch.tensor`): variance. + hidden_size (int): hidden size. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, + row_parallel_mode: ParallelMode) -> Tensor: + input = input - E_x + # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) + ctx.hidden_size = hidden_size + output = input * Var_x + ctx.save_for_backward(output, Var_x) + ctx.row_parallel_mode = row_parallel_mode + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + row_parallel_mode = ctx.row_parallel_mode + x, Var_x = ctx.saved_tensors + # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x + with torch.no_grad(): + output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True) + torch.distributed.all_reduce(output_grad_sum, group=get_parallel_group(row_parallel_mode)) + output_grad_sum /= ctx.hidden_size + + output_grad_mul_x_sum = torch.sum(output_grad * x, dim=-1, keepdim=True) + torch.distributed.all_reduce(output_grad_mul_x_sum, group=get_parallel_group(row_parallel_mode)) + output_grad_mul_x_sum /= ctx.hidden_size + + input_grad = output_grad.clone() + input_grad -= x * output_grad_mul_x_sum + input_grad -= output_grad_sum + input_grad *= Var_x + + return input_grad, None, None, None, None, None, None + + +def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, + row_parallel_mode: ParallelMode) -> Tensor: + r"""Layernorm. + + Args: + input (:class:`torch.tensor`): input matrix. + E_x (:class:`torch.tensor`): mean. + Var_x (:class:`torch.tensor`): variance. + hidden_size (int): hidden size. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + return _Layernorm2p5D.apply(input, E_x, Var_x, hidden_size, row_parallel_mode) + + +class _AllGatherTensor2p5D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor: + ctx.dim = dim + ctx.col_parallel_mode = col_parallel_mode + + outputs = all_gather(inputs, dim, col_parallel_mode) + return outputs + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + grad = reduce_scatter(output_grad, ctx.dim, ctx.col_parallel_mode) + return grad.contiguous(), None, None + + +def all_gather_tensor_2p5d(inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor: + r"""all gather the weight of 2.5D parallelism. + + Args: + inputs (:class:`torch.tensor`): input tensor. + dim (int): dimension of all-gather. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + return _AllGatherTensor2p5D.apply(inputs, dim, col_parallel_mode) + + +class SplitFirst(torch.autograd.Function): + r""" + + Args: + inputs (:class:`torch.tensor`): input tensor. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor: + ctx.tesseract_dim = tesseract_dim + ctx.batch_size = inputs.size(0) + ctx.para_mode = col_parallel_mode + row_rank = gpc.get_local_rank(col_parallel_mode) + + outputs = inputs.chunk(tesseract_dim, dim=0)[row_rank] + return outputs + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + grad_shape = (ctx.batch_size,) + output_grad.shape[1:] + grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device()) + dist.all_gather(list(grad.chunk(ctx.tesseract_dim, dim=0)), + output_grad.contiguous(), + group=gpc.get_group(ctx.para_mode)) + return grad, None, None + + +def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor: + """Splits 2P5D tensor in specified dimension across cols. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + dim (int): Specified dimension in which to split. + + Returns: + :class:`torch.tensor`: The tensor has been split. + """ + dim_size = input_.size(dim) + world_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) + + if world_size <= 1: + return input_ + + assert dim_size % world_size == 0, \ + f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' + + return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), + dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous() + + +class _ReduceTensor2p5D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, parallel_mode): + return all_reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return output_grad, None + + +def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: + r"""All-reduce the input. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _ReduceTensor2p5D.apply(input_, parallel_mode) + + +class _ReduceScatterTensor2p5D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, dim, parallel_mode): + ctx.dim = dim + ctx.parallel_mode = parallel_mode + return reduce_scatter(input_, dim, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return all_gather(output_grad, ctx.dim, ctx.parallel_mode), None, None + + +def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""Reduce-scatter the input. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + dim (int): Dimension to reduce. + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + dim_size = input_.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert dim_size % world_size == 0, \ + f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' + + return _ReduceScatterTensor2p5D.apply(input_, dim, parallel_mode) + + +class _RreduceByBatch2p5D(torch.autograd.Function): + + @staticmethod + def symbolic(graph, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) + return output / reduce_size + return output + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) + ctx.reduce_mean = reduce_mean + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) + ctx.reduce_size = reduce_size + return output.clone() / reduce_size + return output.clone() + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + if ctx.reduce_mean: + return output_grad / ctx.reduce_size, None + else: + return output_grad, None + + +def reduce_by_batch_2p5d(input_, reduce_mean: bool = False) -> Tensor: + r"""All-reduce the input from the model parallel region. + + Args: + input_ (:class:`torch.tensor`): input matrix. + reduce_mean (bool, optional): + If set to ``True``, it will divide the output by column parallel size, default to False. + """ + return _RreduceByBatch2p5D.apply(input_, reduce_mean) diff --git a/colossalai/nn/layer/parallel_2p5d/_utils.py b/colossalai/nn/layer/parallel_2p5d/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1478b25de618978ef2c7e060da33edcb47ecff0b --- /dev/null +++ b/colossalai/nn/layer/parallel_2p5d/_utils.py @@ -0,0 +1,25 @@ +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env + + +def get_tesseract_dim_dep_from_env(): + try: + tesseract_dim = env.tesseract_dim + tesseract_dep = env.tesseract_dep + assert tesseract_dim > 0, 'TESSERACT_DIM must be larger than zero' + assert tesseract_dep > 0, 'TESSERACT_DEP must be larger than zero' + return tesseract_dim, tesseract_dep + + except KeyError as e: + raise EnvironmentError('TESSERACT_DIM or TESSERACT_DEP is not found in the current environment, ' + 'please make sure that you have used the correct process group initializer') + + +def assert_tesseract_initialization(): + assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL) and \ + gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW) and \ + gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP) and \ + gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ), \ + 'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ ' \ + 'must be initialized by the process group initializer' diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/nn/layer/parallel_2p5d/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f849cbbe7b0d0f27067941a1c74ccae2ac6fa6f8 --- /dev/null +++ b/colossalai/nn/layer/parallel_2p5d/layers.py @@ -0,0 +1,1198 @@ +import math +from collections import OrderedDict +from typing import Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from colossalai.communication import broadcast +from colossalai.context import ParallelMode, seed +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn import init as init +from colossalai.registry import LAYERS +from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict) +from colossalai.utils.cuda import get_current_device +from torch import Tensor +from torch.nn import Parameter + +from ..base_layer import ParallelLayer +from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple +from ._operation import (Matmul_AB_2p5D, Matmul_ABT_2p5D, add_bias_2p5d, all_gather_tensor_2p5d, classifier_2p5d, + layernorm_2p5d, reduce_scatter_tensor_2p5d, split_batch_2p5d) +from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env + + +@LAYERS.register_module +class Linear2p5D(ParallelLayer): + r"""Linear layer for 2.5D parallelism. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.skip_bias_add = skip_bias_add + + # parallel setting + assert_tesseract_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(in_features, self.tesseract_dim) + self.hidden_size_per_partition = divide(out_features, self.tesseract_dim) + + # create weight, shape: [k/q, h/q] + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter( + torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)) + + # create bias, shape: [h/q] + if bias: + self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs)) + else: + self.register_parameter('bias', None) + + # initialize parameters + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight.transpose(0, 1) + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # broadcast in dep groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0 and \ + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0: + broadcast_state_dict(local_state, ParallelMode.PARALLEL_2P5D_DEP) + # partition in column groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + # partition in row groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) == 0: + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in row groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in column groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + local_state[weight_key] = local_state[weight_key].transpose(0, 1) + destination.update(local_state) + + def forward(self, x: Tensor) -> Tensor: + # input: [m/dq, n/q, k/q] + # output: [m/dq, n/q, h/q] + out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) + + output = Matmul_AB_2p5D.apply( + x, + self.weight, + self.tesseract_dim, + out_shape, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + + if self.bias is not None: + if self.skip_bias_add: + bias = add_bias_2p5d(None, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank, + self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + return output, bias + else: + output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, + self.row_rank, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, + False, self.data_parallel_rank, self.pipeline_parallel_rank, + self.pipeline_parallel_size, self.tensor_parallel_size) + return output + else: + return output + + +@LAYERS.register_module +class LayerNorm2p5D(ParallelLayer): + r"""Layer Normalization for 2.5D parallelism. + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None): + super().__init__() + + # layer norm config + self.normalized_shape = normalized_shape + self.variance_epsilon = eps + + # parallel setting + assert_tesseract_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() + + # partitioning dimension + self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # * + + # create parameters + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + + self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) + if bias: + self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs)) + else: + self.bias = None + + self._set_tensor_parallel_attribute() + + def _set_tensor_parallel_attribute(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, x: Tensor) -> Tensor: + with torch.no_grad(): + E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] + torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + E_x /= self.normalized_shape + + # Var_x in the block below is the sum of input^2 + Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] + torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + Var_x /= self.normalized_shape + + Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] + # this time 1/sqrt(Var_x + epsilon) + Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) + + output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW) + scale = add_bias_2p5d(None, self.weight, self.partitioned_partition, self.tesseract_dim, self.row_rank, + self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + if self.bias is not None: + bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank, + self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + output = torch.addcmul(bias, scale, output) + else: + output = torch.mul(scale, output) + return output + + +@LAYERS.register_module +class PatchEmbedding2p5D(ParallelLayer): + r"""2D Image to Patch Embedding. + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_()): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + assert_tesseract_initialization() + self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.embed_size = embed_size + self.embed_size_per_partition = embed_size // self.tesseract_dim**2 + + with seed(ParallelMode.TENSOR): + self.weight = Parameter( + torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size), + device=get_current_device(), + dtype=dtype)) + self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) + + self.cls_token = Parameter( + torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)) + self.pos_embed = Parameter( + torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition), + device=get_current_device(), + dtype=dtype)) + + self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) + self._set_tensor_parallel_attribute() + + def _set_tensor_parallel_attribute(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim**2) + set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dim**2) + set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): + with seed(ParallelMode.TENSOR): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + fan_out = self.embed_size + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + bias_initializer(self.bias, fan_in=fan_in) + position_embed_initializer(self.pos_embed) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + cls_token_key = prefix + 'cls_token' + pos_embed_key = prefix + 'pos_embed' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + # cls token + cls_token = state_dict.pop(cls_token_key, None) + if cls_token is not None: + local_state[cls_token_key] = cls_token + # pos embed + pos_embed = state_dict.pop(pos_embed_key, None) + if pos_embed is not None: + local_state[pos_embed_key] = pos_embed + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + cls_token_key = prefix + 'cls_token' + pos_embed_key = prefix + 'pos_embed' + local_state = OrderedDict({ + weight_key: self.weight, + bias_key: self.bias, + cls_token_key: self.cls_token, + pos_embed_key: self.pos_embed + }) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_2p5d(input_, 0) + + B, C, H, W = input_.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + + weight = all_gather_tensor_2p5d(self.weight, 0, ParallelMode.PARALLEL_2P5D_COL) + bias = all_gather_tensor_2p5d(self.bias, 0, ParallelMode.PARALLEL_2P5D_COL) + + output = F.conv2d(input_, weight, bias, stride=self.patch_size) + if self.flatten: + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + + cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL) + pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL) + cls_token = cls_token.expand(output.shape[0], -1, -1) + output = torch.cat((cls_token, output), dim=1) + output = output + pos_embed + + return output + + +@LAYERS.register_module +class Embedding2p5D(ParallelLayer): + r"""Embedding for 2.5D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed โ€œpadโ€, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + + assert_tesseract_initialization() + self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + embed_dim_per_partition = embedding_dim // self.tesseract_dim**2 + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.weight = Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_2p5d(input_, 0) + + weight = all_gather_tensor_2p5d(self.weight, -1, ParallelMode.PARALLEL_2P5D_COL) + + output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + return output + + +@LAYERS.register_module +class VocabParallelEmbedding2p5D(ParallelLayer): + """Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed โ€œpadโ€, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + assert_tesseract_initialization() + self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() + self.num_embeddings_per_partition = divide(self.num_embeddings, self.tesseract_dim) + self.embed_dim_per_partition = divide(self.embed_dim, self.tesseract_dim) + tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None and \ + self.vocab_start_index <= self.padding_idx < self.vocab_end_index: + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, + **self.embed_kwargs) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. + output = reduce_scatter_tensor_2p5d(output_parallel, 0, ParallelMode.PARALLEL_2P5D_COL) + return output + + +@LAYERS.register_module +class Classifier2p5D(ParallelLayer): + r"""Classifier for 2.5D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + assert_tesseract_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(self.in_features, self.tesseract_dim**2) + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.num_classes + col_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_COL)[0] + row_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_ROW)[0] + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2P5D_COL) + broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2P5D_ROW) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + out_shape = input_.shape[:-1] + (self.num_classes,) + + return classifier_2p5d(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank, + self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + + +@LAYERS.register_module +class VocabParallelClassifier2p5D(ParallelLayer): + r"""Vocab parallel classifier layer for 2.5D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + self.in_features = in_features + self.num_classes = num_classes + + # parallel setting + assert_tesseract_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(in_features, self.tesseract_dim) + self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim) + + # create weight, shape: [k/q, h/q] + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.hidden_size_per_partition, self.input_size_per_partition, **factory_kwargs)) + self.has_weight = True + # create bias, shape: [h/q] + if bias: + self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs)) + else: + self.bias = None + + # initialize parameters + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def forward(self, x: Tensor) -> Tensor: + # input: [m/dq, n/q, k/q] + # output: [m/dq, n/q, h/q] + out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) + + output = Matmul_ABT_2p5D.apply( + x, + self.weight, + self.tesseract_dim, + out_shape, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + + if self.bias is not None: + output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank, + self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, False, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + return output diff --git a/colossalai/nn/layer/parallel_3d/__init__.py b/colossalai/nn/layer/parallel_3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9ae255b449ee7f57a08a3bb596102860bb1b60d3 --- /dev/null +++ b/colossalai/nn/layer/parallel_3d/__init__.py @@ -0,0 +1,8 @@ +from ._operation import reduce_by_batch_3d, split_batch_3d, split_tensor_3d +from .layers import (Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D, VocabParallelClassifier3D, + VocabParallelEmbedding3D) + +__all__ = [ + 'reduce_by_batch_3d', 'split_tensor_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', + 'Classifier3D', 'Embedding3D', 'VocabParallelEmbedding3D', 'VocabParallelClassifier3D' +] diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..885d06e6d98d2873fc6a98bfe774301b0b4dde98 --- /dev/null +++ b/colossalai/nn/layer/parallel_3d/_operation.py @@ -0,0 +1,590 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + +from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter +from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc + +from ._utils import get_parallel_mode_from_env, push_async_grad + + +class _Linear3D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx, + input_: Tensor, + weight: Tensor, + weight_id: int, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.weight_id = weight_id + ctx.input_parallel_mode = input_parallel_mode + ctx.weight_parallel_mode = weight_parallel_mode + ctx.output_parallel_mode = output_parallel_mode + + input_ = all_gather(input_, 0, input_parallel_mode) + weight = all_gather(weight, -1, weight_parallel_mode) + ctx.save_for_backward(input_, weight) + + output = torch.matmul(input_, weight) + output = reduce_scatter(output, 0, output_parallel_mode) + + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + input_, weight = ctx.saved_tensors + output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode) + + input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) + input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) + + weight_grad = torch.matmul( + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) + weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + + input_op.wait() + + return input_grad, weight_grad, None, None, None, None + + +def linear_3d( + input_: Tensor, + weight: Tensor, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, +) -> Tensor: + r"""Linear layer for 3D parallelism. + + Args: + input_ (:class:`torch.tensor`): input matrix. + weight (:class:`torch.tensor`): matrix of weight. + input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. + output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Linear3D.apply( + input_, + weight, + id(weight), + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + ) + + +class _Classifier3D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx, + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + weight_id: int, + bias_id: Optional[int], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.use_bias = bias is not None + ctx.weight_id = weight_id + + src_rank = gpc.get_ranks_in_group(input_parallel_mode)[gpc.get_local_rank(output_parallel_mode)] + weight = broadcast(weight, src_rank, input_parallel_mode) + ctx.save_for_backward(input_, weight) + + output = torch.matmul(input_, weight.transpose(0, 1)) + output = all_reduce(output, output_parallel_mode) + + if bias is not None: + ctx.bias_id = bias_id + output += bias + + ctx.src_rank = src_rank + ctx.input_parallel_mode = input_parallel_mode + ctx.weight_parallel_mode = weight_parallel_mode + ctx.output_parallel_mode = output_parallel_mode + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + input_, weight = ctx.saved_tensors + weight_grad = torch.matmul( + output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1])) + weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode) + if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode): + weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + else: + weight_grad = None + + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) + bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode) + bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + else: + bias_grad = None + + input_grad = torch.matmul(output_grad, weight) + + return input_grad, weight_grad, bias_grad, None, None, None, None, None + + +def classifier_3d( + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, +) -> Tensor: + r"""3D parallel classifier. + + Args: + input_ (:class:`torch.tensor`): input matrix. + weight (:class:`torch.tensor`): matrix of weight. + bias (:class:`torch.tensor`): matrix of bias. + input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. + output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Classifier3D.apply( + input_, + weight, + bias, + id(weight), + id(bias) if bias is not None else None, + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + ) + + +class _VocabParallelClassifier3D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx, + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + weight_id: int, + bias_id: Optional[int], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.use_bias = bias is not None + ctx.weight_id = weight_id + + input_ = all_gather(input_, 0, input_parallel_mode) + weight = all_gather(weight.transpose(0, 1), -1, weight_parallel_mode) + ctx.save_for_backward(input_, weight) + + output = torch.matmul(input_, weight) + output = reduce_scatter(output, 0, output_parallel_mode) + + if bias is not None: + ctx.bias_id = bias_id + output += bias + + ctx.input_parallel_mode = input_parallel_mode + ctx.weight_parallel_mode = weight_parallel_mode + ctx.output_parallel_mode = output_parallel_mode + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + input_, weight = ctx.saved_tensors + output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode) + + input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) + input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) + + weight_grad = torch.matmul( + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) + weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) + bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + else: + bias_grad = None + + input_op.wait() + + return input_grad, weight_grad, bias_grad, None, None, None, None, None + + +def vocab_parallel_classifier_3d( + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, +) -> Tensor: + r"""3D vocab parallel classifier. + + Args: + input_ (:class:`torch.tensor`): input matrix. + weight (:class:`torch.tensor`): matrix of weight. + bias (:class:`torch.tensor`): matrix of bias. + input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. + output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _VocabParallelClassifier3D.apply( + input_, + weight, + bias, + id(weight), + id(bias) if bias is not None else None, + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + ) + + +@torch.jit.script +def norm_forward(x, mean, sqr_mean, weight, bias, eps): + mu = x - mean + var = sqr_mean - mean**2 + sigma = torch.sqrt(var + eps) + z = mu / sigma + output = weight * z + bias + + return output, mu, sigma + + +@torch.jit.script +def norm_backward(grad, mu, sigma, weight): + # dbias, dweight = grad, grad * mu / sigma + dz = grad * weight + dmu = dz / sigma + dvar = dz * mu * (-0.5) * sigma**(-3) + dmean = -dmu + dvar = torch.sum(dvar, -1, keepdim=True) + dmean = torch.sum(dmean, -1, keepdim=True) + + return dmu, dmean, dvar + + +class _Layernorm3D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward( + ctx, + input_: Tensor, + weight: Tensor, + bias: Tensor, + weight_id: int, + bias_id: int, + normalized_shape: int, + eps: float, + output_parallel_mode: ParallelMode, + input_x_weight_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.weight_id = weight_id + ctx.bias_id = bias_id + + sum_ = torch.sum(input_, dim=-1, keepdim=True) + sqr_sum = torch.sum(input_**2, dim=-1, keepdim=True) + mean, sqr_mean = all_reduce(torch.stack((sum_, sqr_sum)), output_parallel_mode) / normalized_shape + + output, mu, sigma = norm_forward(input_, mean, sqr_mean, weight, bias, eps) + + ctx.save_for_backward(mu, sigma, weight) + + ctx.normalized_shape = normalized_shape + ctx.output_parallel_mode = output_parallel_mode + ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode + + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + mu, sigma, weight = ctx.saved_tensors + + bias_grad, weight_grad = output_grad, output_grad * mu / sigma + bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1])) + bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1])) + weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + + dmu, dmean, dvar = norm_backward(output_grad, mu, sigma, weight) + dvar, dmean = all_reduce(torch.stack((dvar, dmean)), ctx.output_parallel_mode) + input_grad = dmu + (dmean + 2 * dvar * mu) / ctx.normalized_shape + + return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None + + +def layernorm_3d( + input_: Tensor, + weight: Tensor, + bias: Tensor, + normalized_shape: int, + eps: float, + output_parallel_mode: ParallelMode, + input_x_weight_parallel_mode: ParallelMode, +) -> Tensor: + r"""3D parallel Layernorm. + + Args: + input_ (:class:`torch.tensor`): input matrix. + weight (:class:`torch.tensor`): matrix of weight. + bias (:class:`torch.tensor`): matrix of bias. + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability + output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. + input_x_weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input x weight parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Layernorm3D.apply( + input_, + weight, + bias, + id(weight), + id(bias), + normalized_shape, + eps, + output_parallel_mode, + input_x_weight_parallel_mode, + ) + + +def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""Splits 3D parallel tensor in specified dimension. + + Args: + tensor (:class:`torch.tensor`): Input tensor. + dim (int): Specified dimension in which to split. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): Parallel mode. + + Returns: + :class:`torch.tensor`: The tensor has been split. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + dim_size = tensor.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert dim_size % world_size == 0, \ + f'The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), ' \ + f'cannot split tensor evenly' + if tensor.size(dim) <= 1: + return tensor + output = torch.chunk(tensor, gpc.get_world_size(parallel_mode), + dim=dim)[gpc.get_local_rank(parallel_mode)].contiguous() + return output + + +def split_batch_3d(input_: Tensor, + dim: int = 0, + input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT, + weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor: + r"""Splits 3D tensor in batch. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + dim (int): Specified dimension in which to split. + input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): input parallel mode. + weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): weight parallel mode. + + Returns: + :class:`torch.tensor`: The tensor has been split. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + if input_.size(dim) <= 1: + return input_ + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_world_size = gpc.get_world_size(weight_parallel_mode) + input_world_size = gpc.get_world_size(input_parallel_mode) + output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() + output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous() + return output + + +class _ReduceTensor3D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, parallel_mode): + return all_reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return output_grad, None + + +def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor: + r"""All-reduce the input + + Args: + tensor (:class:`torch.tensor`): Input tensor. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + return _ReduceTensor3D.apply(tensor, parallel_mode) + + +class _AllGatherTensor3D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, dim, parallel_mode): + ctx.dim = dim + ctx.parallel_mode = parallel_mode + output = all_gather(input_, dim, parallel_mode) + return output + + @staticmethod + def backward(ctx, output_grad): + input_grad = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode) + return input_grad, None, None + + +def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""All-reduce the gradient in backward pass. + + Args: + tensor (:class:`torch.tensor`): Input tensor. + dim (int): Dimension to gather. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + return _AllGatherTensor3D.apply(tensor, dim, parallel_mode) + + +class _ReduceScatterTensor3D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, dim, parallel_mode): + ctx.dim = dim + ctx.parallel_mode = parallel_mode + return reduce_scatter(input_, dim, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode) + return input_grad, None, None + + +def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""Reduce-scatter the input. + + Args: + tensor (:class:`torch.tensor`): Input tensor. + dim (int): Dimension to scatter. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + dim_size = tensor.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert dim_size % world_size == 0, \ + f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size}).' + + return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode) + + +class _ReduceByBatch3D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, + input_: Tensor, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + reduce_mean: bool = False) -> Tensor: + output = all_reduce(input_, input_parallel_mode) + output = all_reduce(output, weight_parallel_mode) + ctx.reduce_mean = reduce_mean + if reduce_mean: + reduce_size = gpc.get_world_size(input_parallel_mode) * gpc.get_world_size(weight_parallel_mode) + ctx.reduce_size = reduce_size + return output.clone() / reduce_size + return output.clone() + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + if ctx.reduce_mean: + return output_grad / ctx.reduce_size, None, None, None + else: + return output_grad, None, None, None + + +def reduce_by_batch_3d(tensor: Tensor, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + reduce_mean: bool = False) -> Tensor: + r"""All-reduce the input from the model parallel region. + + Args: + input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. + reduce_mean (bool, optional): If set to ``True``, it will divide the output by + (input parallel size * weight parallel size), default to False. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean) diff --git a/colossalai/nn/layer/parallel_3d/_utils.py b/colossalai/nn/layer/parallel_3d/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..364191a79f88450ca8701d96258f8e34b7b9b784 --- /dev/null +++ b/colossalai/nn/layer/parallel_3d/_utils.py @@ -0,0 +1,99 @@ +from collections import OrderedDict +from functools import partial + +import torch +from torch import Tensor + +from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env + + +def get_depth_from_env() -> int: + try: + depth = env.depth_3d + assert depth > 0, 'DEPTH must be greater than zero' + return depth + + except KeyError as e: + raise EnvironmentError('DEPTH is not found in the current environment, ' + 'please make sure that you have used the correct process group initializer') + + +def get_parallel_mode_from_env(group): + assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_X_WEIGHT_3D], \ + f'{group} is not valid for 3D tensor parallelism.' + return getattr(env, group) + + +def swap_in_out_group(): + env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d + env.input_x_weight_group_3d, env.output_x_weight_group_3d = ( + env.output_x_weight_group_3d, + env.input_x_weight_group_3d, + ) + + +def dbg_check_shape(tensor: Tensor, shape: tuple): + rank = gpc.get_global_rank() + if rank == 0: + print(tensor.shape) + assert tensor.shape == shape, \ + '{} does not match {}'.format(tensor.shape, shape) + + +class AsyncGradientBucket(object): + + def __init__(self): + self.bucket = OrderedDict() + + def __len__(self): + return len(self.bucket) + + def push(self, async_op, grad_tensor, param_id): + self.bucket[param_id] = tuple((async_op, grad_tensor)) + return torch.zeros_like(grad_tensor, dtype=grad_tensor.dtype, device=grad_tensor.device) + + def pop(self, param_id): + grad = None + if param_id in self.bucket: + op, grad = self.bucket.pop(param_id) + if op is not None: + op.wait() + return grad + + def synchronize(self, params): + for p in params: + i = id(p) + if i in self.bucket: + op, grad = self.bucket.pop(i) + if op is not None: + op.wait() + p.grad.add_(grad) + + +_async_grad_bucket = AsyncGradientBucket() + + +def push_async_grad(op, grad, param_id): + return _async_grad_bucket.push(op, grad, param_id) + + +def pop_async_grad(param_id): + return _async_grad_bucket.pop(param_id) + + +def _async_grad_hook(grad, param_id): + grad.add_(pop_async_grad(param_id)) + return grad + + +def register_async_grad_hook(param): + param.register_hook(partial(_async_grad_hook, param_id=id(param))) + + +def synchronize(params=list()): + _async_grad_bucket.synchronize(params) + torch.cuda.default_stream().synchronize() + if len(_async_grad_bucket) > 0: + raise RuntimeError(f"{len(_async_grad_bucket)} asynchronous gradient(s) not collected.") diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..0a1db68000dbc3b0ef911662be9efaa682405269 --- /dev/null +++ b/colossalai/nn/layer/parallel_3d/layers.py @@ -0,0 +1,1218 @@ +import math +from collections import OrderedDict +from typing import Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Parameter + +from colossalai.communication import all_reduce, broadcast +from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D +from colossalai.context import ParallelMode, seed +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn import init as init +from colossalai.nn.layer.base_layer import ParallelLayer +from colossalai.registry import LAYERS +from colossalai.utils.checkpointing import ( + broadcast_state_dict, + gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict, +) +from colossalai.utils.cuda import get_current_device + +from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple +from ._operation import ( + all_gather_tensor_3d, + classifier_3d, + layernorm_3d, + linear_3d, + reduce_scatter_tensor_3d, + split_batch_3d, + split_tensor_3d, + vocab_parallel_classifier_3d, +) +from ._utils import get_depth_from_env, get_parallel_mode_from_env, register_async_grad_hook, swap_in_out_group + + +@LAYERS.register_module +class LayerNorm3D(ParallelLayer): + r"""Layer Normalization for 3D parallelism. + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-12. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None): + + super().__init__() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) + self.depth = get_depth_from_env() + self.normalized_shape = normalized_shape + self.normalized_shape_per_partition = divide(normalized_shape, self.depth) + + self.weight = Parameter( + torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)) + if bias: + self.bias = Parameter( + torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + self.variance_epsilon = eps + self.reset_parameters() + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self) -> None: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.depth) + + def reset_parameters(self) -> None: + init.ones_()(self.weight) + register_async_grad_hook(self.weight) + if self.bias is not None: + init.zeros_()(self.bias) + register_async_grad_hook(self.bias) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight.transpose(0, 1) + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True, + }, + ) + # broadcast in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = broadcast_state_dict(local_state, self.input_parallel_mode) + # broadcast in weight groups + local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + return layernorm_3d( + input_, + self.weight, + self.bias, + self.normalized_shape, + self.variance_epsilon, + self.output_parallel_mode, + self.input_x_weight_parallel_mode, + ) + + +@LAYERS.register_module +class Linear3D(ParallelLayer): + r"""Linear layer for 3D parallelism. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D) + self.depth = get_depth_from_env() + self.skip_bias_add = skip_bias_add + self.in_features_per_partition = divide(in_features, self.depth) + self.out_features_per_partition = divide(out_features, self.depth**2) + self.bias_features_per_partition = divide(out_features, self.depth) + + self.weight = Parameter( + torch.empty(self.in_features_per_partition, + self.out_features_per_partition, + device=get_current_device(), + dtype=dtype)) + if bias: + self.bias = Parameter( + torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + swap_in_out_group() + + def _set_tensor_parallel_attributes(self) -> None: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.depth) + + def _sync_grad_hook(self, grad) -> Tensor: + grad = all_reduce(grad.clone(), self.output_x_weight_parallel_mode) + return grad + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.out_features + + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + register_async_grad_hook(self.weight) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, + gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], + self.output_x_weight_parallel_mode) + self.bias.register_hook(self._sync_grad_hook) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight.transpose(0, 1) + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + # partition in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in weight groups + local_state = partition_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in weight groups + local_state = gather_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + # gather in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + local_state[weight_key] = local_state[weight_key].transpose(0, 1) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + output = linear_3d( + input_, + self.weight, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + ) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias + + +@LAYERS.register_module +class Classifier3D(ParallelLayer): + r"""Classifier for 3D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.depth = get_depth_from_env() + self.in_features_per_partition = divide(in_features, self.depth) + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self) -> None: + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.num_classes + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + broadcast(self.weight, gpc.get_ranks_in_group(self.weight_parallel_mode)[0], self.weight_parallel_mode) + + register_async_grad_hook(self.weight) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], ParallelMode.TENSOR) + register_async_grad_hook(self.bias) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + # broadcast in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = broadcast_state_dict(local_state, self.input_parallel_mode) + # broadcast in weight groups + local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + return classifier_3d( + input_, + self.weight, + self.bias, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + ) + + +@LAYERS.register_module +class VocabParallelClassifier3D(ParallelLayer): + r"""Vocab parallel classifier layer for 3D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D) + self.depth = get_depth_from_env() + self.in_features_per_partition = divide(in_features, self.depth) + self.out_features_per_partition = divide(num_classes, self.depth**2) + self.bias_features_per_partition = divide(num_classes, self.depth) + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.out_features_per_partition, + self.in_features_per_partition, + device=get_current_device(), + dtype=dtype)) + self.has_weight = True + if bias: + self.bias = Parameter( + torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + swap_in_out_group() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self) -> None: + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.depth) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.num_classes + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + register_async_grad_hook(self.weight) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, + gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], + self.output_x_weight_parallel_mode) + register_async_grad_hook(self.bias) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + # partition in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in weight groups + local_state = partition_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in weight groups + local_state = gather_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + # gather in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + return vocab_parallel_classifier_3d( + input_, + self.weight, + self.bias, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + ) + + +@LAYERS.register_module +class PatchEmbedding3D(ParallelLayer): + r"""2D Image to Patch Embedding. + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_()): + super().__init__() + self.depth = get_depth_from_env() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.embed_size = embed_size + embed_size_per_partition = embed_size // self.depth + self.flatten = flatten + + self.weight = nn.Parameter( + torch.empty((embed_size_per_partition, in_chans, *self.patch_size), + device=get_current_device(), + dtype=dtype)) + self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype)) + + self.cls_token = nn.Parameter( + torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) + self.pos_embed = nn.Parameter( + torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self) -> None: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth) + set_tensor_parallel_attribute_by_partition(self.bias, self.depth) + set_tensor_parallel_attribute_by_partition(self.cls_token, self.depth) + set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth) + + def _sync_grad_hook(self, grad) -> Tensor: + grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode) + return grad + + def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + fan_out = self.embed_size + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + bias_initializer(self.bias, fan_in=fan_in) + position_embed_initializer(self.pos_embed) + + src_rank = gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0] + broadcast(self.weight, src_rank, self.input_x_weight_parallel_mode) + broadcast(self.bias, src_rank, self.input_x_weight_parallel_mode) + broadcast(self.pos_embed, src_rank, self.input_x_weight_parallel_mode) + + self.weight.register_hook(self._sync_grad_hook) + self.bias.register_hook(self._sync_grad_hook) + self.cls_token.register_hook(self._sync_grad_hook) + self.pos_embed.register_hook(self._sync_grad_hook) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + cls_token_key = prefix + 'cls_token' + pos_embed_key = prefix + 'pos_embed' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + # cls token + cls_token = state_dict.pop(cls_token_key, None) + if cls_token is not None: + local_state[cls_token_key] = cls_token + # pos embed + pos_embed = state_dict.pop(pos_embed_key, None) + if pos_embed is not None: + local_state[pos_embed_key] = pos_embed + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + ) + # broadcast in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = broadcast_state_dict(local_state, self.input_parallel_mode) + # broadcast in weight groups + local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + cls_token_key = prefix + 'cls_token' + pos_embed_key = prefix + 'pos_embed' + local_state = OrderedDict({ + weight_key: self.weight, + bias_key: self.bias, + cls_token_key: self.cls_token, + pos_embed_key: self.pos_embed + }) + + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_3d(input_, + input_parallel_mode=self.input_parallel_mode, + weight_parallel_mode=self.weight_parallel_mode) + output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) + if self.flatten: + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + + cls_token = self.cls_token.expand(output.shape[0], -1, -1) + output = torch.cat((cls_token, output), dim=1) + output = output + self.pos_embed + + return output + + +@LAYERS.register_module +class Embedding3D(ParallelLayer): + r"""Embedding for 3D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed โ€œpadโ€, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.depth = get_depth_from_env() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) + + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + embed_dim_per_partition = divide(embedding_dim, self.depth) + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.weight = nn.Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self) -> None: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth) + + def _sync_grad_hook(self, grad) -> Tensor: + grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode) + return grad + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + broadcast(self.weight, + gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode) + self.weight.register_hook(self._sync_grad_hook) + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + # broadcast in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = broadcast_state_dict(local_state, self.input_parallel_mode) + # broadcast in weight groups + local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_3d(input_, + input_parallel_mode=self.input_parallel_mode, + weight_parallel_mode=self.weight_parallel_mode) + output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + return output + + +@LAYERS.register_module +class VocabParallelEmbedding3D(ParallelLayer): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed โ€œpadโ€, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.depth = get_depth_from_env() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2) + self.embed_dim_per_partition = divide(self.embed_dim, self.depth) + vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode) + self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition * self.depth + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition * self.depth + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None and \ + self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + # partition in weight groups + local_state = partition_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + + # gather in weight groups + local_state = gather_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) + + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + weight = all_gather_tensor_3d(self.weight, 0, self.weight_parallel_mode) + + output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + output_parallel[input_mask, :] = 0. + output = reduce_scatter_tensor_3d(output_parallel, 0, self.input_parallel_mode) + + return output diff --git a/colossalai/nn/layer/parallel_sequence/__init__.py b/colossalai/nn/layer/parallel_sequence/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4fa9eed6f34b8ccdcf03935337bc96ba705530d0 --- /dev/null +++ b/colossalai/nn/layer/parallel_sequence/__init__.py @@ -0,0 +1,4 @@ +from ._operation import RingQK, RingAV +from .layers import TransformerSelfAttentionRing + +__all__ = ['TransformerSelfAttentionRing', 'RingAV', 'RingQK'] diff --git a/colossalai/nn/layer/parallel_sequence/_operation.py b/colossalai/nn/layer/parallel_sequence/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..fc80494224c6d2ec3176f40c47733133a422b88c --- /dev/null +++ b/colossalai/nn/layer/parallel_sequence/_operation.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +from torch import distributed as dist + +from colossalai.communication import ring_forward +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range +from colossalai.utils import get_current_device +from torch.cuda.amp import custom_bwd, custom_fwd + + +class RingQK(torch.autograd.Function): + """ + Calculate QK in a ring-exchange style + """ + + @staticmethod + @custom_fwd + def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length): + # save tensor for backward + ctx.save_for_backward(sub_q, sub_k) + ctx.sub_seq_length = sub_seq_length + + # create local segment of attention score + attention_score = torch.empty(batch_size * num_attention_heads, + sub_seq_length, + sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE), + dtype=sub_q.dtype, + device=get_current_device()) + + # compute local QK^T + part_a = torch.matmul(sub_q, sub_k.transpose(2, 1)) + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) + start_idx = local_rank * sub_seq_length + end_idx = (local_rank + 1) * sub_seq_length + attention_score[:, :, start_idx:end_idx] = part_a + + # compute QK^T in ring-all-reduce style + for i in range(local_world_size - 1): + sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE) + start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length) + part_a = torch.matmul(sub_q, sub_k.transpose(2, 1)) + attention_score[:, :, start_idx:end_idx] = part_a + + return attention_score + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + sub_q, sub_k, = ctx.saved_tensors + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) + + # calculate gradient of sub_k + grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q) + + dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE)) + grad_k = grad_k[:, local_rank * ctx.sub_seq_length:(local_rank + 1) * ctx.sub_seq_length] + grad_k /= local_world_size + + # calculate gradient for sub_q + grad_q = torch.zeros_like( + sub_q, + dtype=sub_q.dtype, + device=get_current_device(), + ) + + # compute with local sub_k + start_idx, end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length) + grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k) + + # compute QK^T in ring-all-reduce style + for i in range(local_world_size - 1): + sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE) + start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length) + grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k) + + grad_q /= local_world_size + + return grad_q, grad_k, None, None, None + + +class RingAV(torch.autograd.Function): + """ + Calculate AV in a ring-exchange style + """ + + @staticmethod + @custom_fwd + def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads, attention_head_size, sub_seq_length): + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) + local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length) + + sub_attention_result = torch.zeros(batch_size * num_attention_heads, + sub_seq_length, + attention_head_size, + device=get_current_device(), + dtype=attention_score.dtype) + + # save tensors for backward + ctx.save_for_backward(attention_score, sub_v) + ctx.sub_seq_length = sub_seq_length + + # compute local AV + part_av = torch.matmul(attention_score[:, :, local_start_idx:local_end_idx], sub_v) + sub_attention_result += part_av + + # compute AV in ring - all - reduce style + for i in range(local_world_size - 1): + sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE) + start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length) + + # compute QK^T + part_av = torch.matmul(attention_score[:, :, start_idx:end_idx], sub_v) + sub_attention_result += part_av + return sub_attention_result + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) + local_start_idx, local_end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length) + attention_scores, sub_v = ctx.saved_tensors + + # calculate gradient of v + grad_v = torch.matmul(attention_scores.transpose(2, 1), grad_output) + dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE)) + grad_v = grad_v[:, local_start_idx:local_end_idx] + grad_v /= local_world_size + + # calculate gradient for attention score + grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device()) + + # compute with local sub_k + grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1)) + + # compute QK^T in ring-all-reduce style + for i in range(local_world_size - 1): + sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE) + start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length) + + # compute grad_q + grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1)) + + return grad_attention_score, grad_v, None, None, None, None diff --git a/colossalai/nn/layer/parallel_sequence/_utils.py b/colossalai/nn/layer/parallel_sequence/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9fad8fab23d2e89d70ef2d82789107db78ebaf08 --- /dev/null +++ b/colossalai/nn/layer/parallel_sequence/_utils.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + + +def _calc_incoming_device_range(i, rank, world_size, sub_seq_length): + device_of_incoming_k = (rank - i - 1) % world_size + start_idx = sub_seq_length * device_of_incoming_k + end_idx = sub_seq_length * (device_of_incoming_k + 1) + return start_idx, end_idx + + +def _calc_current_device_range(rank, sub_seq_length): + start_idx = sub_seq_length * rank + end_idx = sub_seq_length * (rank + 1) + return start_idx, end_idx diff --git a/colossalai/nn/layer/parallel_sequence/layers.py b/colossalai/nn/layer/parallel_sequence/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d9486217bbc93a9f75cfda809701cafaa15957cc --- /dev/null +++ b/colossalai/nn/layer/parallel_sequence/layers.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +import colossalai + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV +from colossalai.registry import LAYERS +from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType +from colossalai.kernel import FusedScaleMaskSoftmax +from colossalai.context import seed + + +@LAYERS.register_module +class TransformerSelfAttentionRing(nn.Module): + """Parallel self-attention layer abstract class. + Self-attention layer takes input with size [b, s, h] + and returns output of the same size. + + Args: + hidden_size (int): hidden size. + num_attention_heads (int): number of attention heads. + attention_dropout (float): dropout probability for attention layer. + attention_mask_func (:class:`typing.Callable`): Mask function to be applied. + layer_number (int): number of layers. + + """ + + def __init__(self, + hidden_size, + num_attention_heads, + attention_dropout, + attention_mask_func, + layer_number, + apply_query_key_layer_scaling: bool = False, + convert_fp16_to_fp32_in_softmax: bool = False, + attn_mask_type=AttnMaskType.padding, + masked_softmax_fusion=True, + fp16=False, + bf16=False): + super().__init__() + self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_mask_func = attention_mask_func + self.layer_number = layer_number + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.attn_mask_type = attn_mask_type + assert self.layer_number > 0 + self.attention_dropout = attention_dropout + + if self.apply_query_key_layer_scaling: + self.convert_fp16_to_fp32_in_softmax = True + + assert self.hidden_size % self.num_attention_heads == 0, \ + 'hidden size is not divisible by the number of attention heads' + + self.hidden_size_per_attention_head = self.hidden_size // num_attention_heads + + self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE) + + # Strided linear layer. + self.query_key_value = _Linear( + hidden_size, + 3 * self.hidden_size, + ) + + self.coeff = None + self.norm_factor = math.sqrt(self.hidden_size) + + if self.apply_query_key_layer_scaling: + self.coeff = layer_number + self.norm_factor *= self.coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax(fp16, bf16, self.attn_mask_type, masked_softmax_fusion, + self.attention_mask_func, self.convert_fp16_to_fp32_in_softmax, + self.coeff) + + self.attention_dropout = nn.Dropout(attention_dropout) + + # Output. + self.dense = _Linear(hidden_size, hidden_size, bias=True, skip_bias_add=True) + + def forward(self, hidden_states, attention_mask): + # hidden_states: [sub_seq_len, batch_size, hidden_size] + # attention_mask: [batch_size, 1, sub_seq_len, seq_len] + sub_seq_length, batch_size, hidden_size = hidden_states.size() + + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads shape change: + # [sub_seq_len, batch_size, hidden_size] --> [sub_seq_len, batch_size, (3 * head_size * num_heads)] + mixed_x_layer = self.query_key_value(hidden_states) + + # [sub_seq_len, batch_size, num_heads, 3 * head_size] --> 3 [sub_seq_len, batch_size, num_heads, head_size] + new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # split into query, key and value + last_dim = mixed_x_layer.dim() - 1 + last_dim_value = mixed_x_layer.size(-1) + assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \ + 'cannot be divided into query, key and value' + partition_size = last_dim_value // 3 + (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, partition_size, dim=last_dim) + + # attention scores: [batch_size, num_heads, sub_seq_len, seq_len] + output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), + key_layer.size(0) * self.world_size) + + # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size] + key_layer = key_layer.view(key_layer.size(0), output_size[0] * output_size[1], -1) + + # attention_scores: [batch_size * num_heads, sub_seq_len, seq_len] + attention_scores = RingQK.apply( + query_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size] + key_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size], + batch_size, + self.num_attention_heads, + sub_seq_length) + + attention_scores /= self.norm_factor + + # change view to [batch_size, num_heads, sub_seq_len, seq_len] + attention_scores = attention_scores.view(*output_size) + + # change shape to [batch_size, num_heads, sub_seq_len, seq_len] + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + with seed(ParallelMode.TENSOR): + attention_probs = self.attention_dropout(attention_probs) + + # context layer shape: [batch_size, num_heads, sub_seq_len, head_size] + output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + + # change view [sub_seq_len, batch_size * num_heads, head_size] + value_layer = value_layer.contiguous().view(value_layer.size(0), output_size[0] * output_size[1], -1) + + # # change view [b * num_heads, sub_seq_len, seq_len] + attention_probs = attention_probs.view( + attention_probs.size(0) * attention_probs.size(1), attention_probs.size(2), attention_probs.size(3)) + + # matmul: [batch_size * num_heads, sub_seq_len, head_size] + context_layer = RingAV.apply(attention_probs, + value_layer.transpose(0, 1).contiguous(), batch_size, self.num_attention_heads, + self.hidden_size_per_attention_head, sub_seq_length) + + # change view [batch_size, num_heads, sub_seq_len, head_size] + context_layer = context_layer.view(*output_size) + + # [batch_size, num_heads, sub_seq_len, head_size] -> [sub_seq_len, batch_size, num_heads, head_size] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_attention_head * + self.num_attention_heads,) + context_layer = context_layer.view(*new_context_layer_shape) + + output, bias = self.dense(context_layer) + + return output, bias + + def __repr__(self): + return f'TransformerSelfAttentionRing(apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, ' \ + f'layer_number={self.layer_number}, hidden_size:{self.hidden_size}, attention_dropout={self.attention_dropout}, ' \ + f'attn_mask_type={self.attn_mask_type}, num_attention_heads={self.num_attention_heads}, ' \ + f'hidden_size_per_attention_head={self.hidden_size_per_attention_head}, coeff={self.coeff}, norm_factor={self.norm_factor}, ' \ + f'convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax})' + + +class _Linear(nn.Module): + """Linear layer with column parallelism. + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimations where bias + can be fused with other elementwise operations. we skip + adding bias but instead return it. + """ + + def __init__(self, input_size, output_size, bias=True, skip_bias_add=False): + super(_Linear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + + self.weight = Parameter(torch.empty( + self.output_size, + self.input_size, + )) + nn.init.xavier_normal_(self.weight) + + if bias: + self.bias = Parameter(torch.empty(self.output_size)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_): + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output = F.linear(input_, self.weight, bias) + + if self.skip_bias_add: + return output, self.bias + else: + return output + + def __repr__(self): + return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \ + f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})' diff --git a/colossalai/nn/layer/utils/__init__.py b/colossalai/nn/layer/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94872b4df6a98034980b6abed19c1f17f65392a2 --- /dev/null +++ b/colossalai/nn/layer/utils/__init__.py @@ -0,0 +1,7 @@ +from .common import (ACT2FN, CheckpointModule, _ntuple, divide, get_tensor_parallel_mode, + set_tensor_parallel_attribute_by_partition, set_tensor_parallel_attribute_by_size, to_2tuple) + +__all__ = [ + 'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size', + 'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple' +] diff --git a/colossalai/nn/layer/utils/common.py b/colossalai/nn/layer/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..f2297304fdc939c6a57e9eaa0f52ee268ead46fd --- /dev/null +++ b/colossalai/nn/layer/utils/common.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import collections.abc +from itertools import repeat + +import numpy as np +import torch +from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.utils import checkpoint +from torch import Tensor, nn + + +class CheckpointModule(nn.Module): + + def __init__(self, checkpoint: bool = True, offload: bool = False): + super().__init__() + self.checkpoint = checkpoint + self._use_checkpoint = checkpoint + self._offload = offload + + def _forward(self, *args, **kwargs): + raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward') + + def forward(self, *args, **kwargs): + if self._use_checkpoint: + return checkpoint(self._forward, self._offload, *args, **kwargs) + else: + return self._forward(*args, **kwargs) + + def train(self, mode: bool = True): + self._use_checkpoint = self.checkpoint + return super().train(mode=mode) + + def eval(self): + self._use_checkpoint = False + return super().eval() + + +def divide(numerator, denominator): + """Only allow exact division. + + Args: + numerator (int): Numerator of the division. + denominator (int): Denominator of the division. + + Returns: + int: the result of exact division. + """ + assert denominator != 0, 'denominator can not be zero' + assert numerator % denominator == 0, \ + '{} is not divisible by {}'.format(numerator, denominator) + return numerator // denominator + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +def set_tensor_parallel_attribute_by_size(param, size): + setattr(param, IS_TENSOR_PARALLEL, True) + setattr(param, NUM_PARTITIONS, size // np.prod(param.shape)) + + +def set_tensor_parallel_attribute_by_partition(param, num_partitions): + setattr(param, IS_TENSOR_PARALLEL, True) + setattr(param, NUM_PARTITIONS, num_partitions) + + +def get_tensor_parallel_mode(): + return env.mode + + +# From PyTorch internals + + +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) diff --git a/colossalai/nn/layer/vanilla/__init__.py b/colossalai/nn/layer/vanilla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d767b8886f53fe5ddb697fd9fb4ed261cfd05c3 --- /dev/null +++ b/colossalai/nn/layer/vanilla/__init__.py @@ -0,0 +1,14 @@ +from .layers import ( + DropPath, + VanillaClassifier, + VanillaLayerNorm, + VanillaLinear, + VanillaPatchEmbedding, + WrappedDropout, + WrappedDropPath, +) + +__all__ = [ + "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath", + "VanillaLinear" +] diff --git a/colossalai/nn/layer/vanilla/layers.py b/colossalai/nn/layer/vanilla/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..225aed3916a6dbacaa8876ea994c2d2441713338 --- /dev/null +++ b/colossalai/nn/layer/vanilla/layers.py @@ -0,0 +1,341 @@ +import math +from typing import Callable + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch import nn as nn +from torch.nn.parameter import Parameter + +from colossalai.context import seed +from colossalai.nn import init as init +from colossalai.registry import LAYERS +from colossalai.utils.cuda import get_current_device + +from ..utils import to_2tuple + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + Args: + drop_prob (float, optional): probability of dropping path, defaults 0.0. + training (bool, optional): whether in training progress, defaults False. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + + Args: + drop_prob (float, optional): probability of dropping path, defaults None. + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class WrappedDropout(nn.Module): + r"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. During training, randomly zeroes + some elements of the input tensor with probability p using samples from a Bernoulli distribution. Each + channel will be zeroed out independently on every forward call. Furthermore, the outputs are scaled by a factor of + 1/(1-p) during training. This means that during evaluation the module simply computes an identity function. + + Args: + p (float, optional): probability of an element to be zeroed, defaults 0.5. + inplace (bool, optional): whether to do dropout in-place, default to be False. + mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + def __init__(self, p: float = 0.5, inplace: bool = False, mode=None): + super().__init__() + if p < 0 or p > 1: + raise ValueError("dropout probability has to be between 0 and 1, " + "but got {}".format(p)) + self.p = p + self.inplace = inplace + if mode is None: + self.func = self.nonefunc + else: + self.func = self.normalfunc + self.mode = mode + + def nonefunc(self, inputs): + return F.dropout(inputs, self.p, self.training, self.inplace) + + def normalfunc(self, inputs): + with seed(self.mode): + return F.dropout(inputs, self.p, self.training, self.inplace) + + def forward(self, inputs): + return self.func(inputs) + + +class WrappedDropPath(nn.Module): + r"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + Here, it is wrapped with the context of seed manager. + + Args: + p (float, optional): probability of dropping path, defaults 0.0. + mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + def __init__(self, p: float = 0., mode=None): + super().__init__() + self.p = p + self.mode = mode + if self.mode is None: + self.func = self.nonefunc + else: + self.func = self.normalfunc + self.mode = mode + + def nonefunc(self, inputs): + return drop_path(inputs, self.p, self.training) + + def normalfunc(self, inputs): + with seed(self.mode): + return drop_path(inputs, self.p, self.training) + + def forward(self, inputs): + return self.func(inputs) + + +@LAYERS.register_module +class VanillaPatchEmbedding(nn.Module): + r""" + 2D Image to Patch Embedding + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_()): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.weight = nn.Parameter( + torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)) + self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) + self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype)) + self.pos_embed = nn.Parameter( + torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) + + def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): + fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight) + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + bias_initializer(self.bias, fan_in=fan_in) + position_embed_initializer(self.pos_embed) + + def forward(self, input_: Tensor) -> Tensor: + B, C, H, W = input_.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) + if self.flatten: + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + + cls_token = self.cls_token.expand(output.shape[0], -1, -1) + output = torch.cat((cls_token, output), dim=1) + output = output + self.pos_embed + return output + + +@LAYERS.register_module +class VanillaClassifier(nn.Module): + r"""Dense linear classifier. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: nn.Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = nn.Parameter( + torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype)) + self.has_weight = True + if bias: + self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + + def reset_parameters(self, weight_initializer, bias_initializer): + fan_in, fan_out = self.in_features, self.num_classes + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tensor: + return F.linear(input_, self.weight, self.bias) + + +@LAYERS.register_module +class VanillaLayerNorm(nn.Module): + r""" + Layer Normalization for colossalai + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): + super().__init__() + + self.normalized_shape = (normalized_shape,) + self.variance_epsilon = eps + + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + + self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs)) + if bias: + self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs)) + else: + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon) + + +@LAYERS.register_module +class VanillaLinear(nn.Module): + """Linear layer. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + skip_bias_add: bool (optional, default to be false). + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.skip_bias_add = skip_bias_add + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + weight_initializer(self.weight, fan_in=in_features, fan_out=out_features) + if self.bias is not None: + bias_initializer(self.bias, fan_in=in_features) + + def forward(self, input: Tensor) -> Tensor: + if not self.skip_bias_add: + return F.linear(input, self.weight, self.bias) + else: + return F.linear(input, self.weight), self.bias diff --git a/colossalai/nn/layer/wrapper/__init__.py b/colossalai/nn/layer/wrapper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7d90d887ec6612e351713e508d96b106b767a81 --- /dev/null +++ b/colossalai/nn/layer/wrapper/__init__.py @@ -0,0 +1,3 @@ +from .pipeline_wrapper import PipelineSharedModuleWrapper + +__all__ = ['PipelineSharedModuleWrapper'] diff --git a/colossalai/nn/layer/wrapper/pipeline_wrapper.py b/colossalai/nn/layer/wrapper/pipeline_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..ef1d794cc68f15a1885776016601e1b850d933b8 --- /dev/null +++ b/colossalai/nn/layer/wrapper/pipeline_wrapper.py @@ -0,0 +1,46 @@ +import torch.nn as nn +import torch.distributed as dist +from typing import List, Tuple, Union +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + + +class PipelineSharedModuleWrapper: + + def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None: + assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}' + self.pipeline_ranks = pipeline_ranks + self.group = None + self.ranks_in_group = None + self._init_group() + + def _init_group(self): + world_size = gpc.get_world_size(ParallelMode.GLOBAL) + dp_size = gpc.get_world_size(ParallelMode.DATA) + pp_size = gpc.get_world_size(ParallelMode.PIPELINE) + rank = gpc.get_global_rank() + num_dp_groups = world_size // dp_size + num_pp_stages = num_dp_groups // pp_size + for i in range(dp_size): + for j in range(num_pp_stages): + pipeline_ranks = list(range(i * num_dp_groups + j, (i + 1) * num_dp_groups, num_pp_stages)) + sub_ranks = [pipeline_ranks[idx] for idx in self.pipeline_ranks] + group = dist.new_group(sub_ranks) + if rank in sub_ranks: + self.group = group + self.ranks_in_group = sub_ranks + + def register_module(self, module: nn.Module): + assert self.ranks_in_group is not None,\ + f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' + src = self.ranks_in_group[self.pipeline_ranks[0]] + for p in module.parameters(): + setattr(p, 'pipeline_shared_module_pg', self.group) + dist.broadcast(p, src, group=self.group) + + def register_parameter(self, param: nn.Parameter): + assert self.ranks_in_group is not None,\ + f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' + src = self.ranks_in_group[self.pipeline_ranks[0]] + setattr(param, 'pipeline_shared_module_pg', self.group) + dist.broadcast(param, src, group=self.group) diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..373e4ec9468bc13317d74c19b5922073a5cb8c0c --- /dev/null +++ b/colossalai/nn/loss/__init__.py @@ -0,0 +1,41 @@ +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn.layer.utils import get_tensor_parallel_mode +from torch import nn +from torch.nn.modules.loss import * +from torch.nn.modules.loss import _Loss + +from .loss_1d import VocabParallelCrossEntropyLoss1D +from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D +from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D +from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D +from .loss_moe import MoeCrossEntropyLoss, MoeLoss + +_parallel_cross_entropy = { + '2d': CrossEntropyLoss2D, + '2.5d': CrossEntropyLoss2p5D, + '3d': CrossEntropyLoss3D, +} + +_vocab_parallel_cross_entropy = { + '1d': VocabParallelCrossEntropyLoss1D, + '2d': VocabParallelCrossEntropyLoss2D, + '2.5d': VocabParallelCrossEntropyLoss2p5D, + '3d': VocabParallelCrossEntropyLoss3D, +} + + +class CrossEntropyLoss(_Loss): + + def __init__(self, reduction: bool = True, *args, **kwargs): + super().__init__() + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel is not None and env.vocab_parallel: + self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs) + elif tensor_parallel is None or tensor_parallel == '1d': + reduction = 'mean' if reduction else 'none' + self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs) + else: + self.loss = _parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs) + + def forward(self, *args): + return self.loss(*args) diff --git a/colossalai/nn/loss/loss_1d.py b/colossalai/nn/loss/loss_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..58d57fdc821df3a9c215c1bdf4c1730653646d4f --- /dev/null +++ b/colossalai/nn/loss/loss_1d.py @@ -0,0 +1,105 @@ +import torch +import torch.distributed as dist +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.registry import LOSSES +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.modules.loss import _Loss + + +class _VocabParallelCrossEntropy1D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, vocab_parallel_logits, targets, process_group): + if process_group is None: + process_group = gpc.get_group(ParallelMode.PARALLEL_1D) + + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] + torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group) + # Subtract the maximum value. + vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + + # Get the partition's vocab indecies + partition_vocab_size = vocab_parallel_logits.size()[-1] + rank = dist.get_rank(process_group) + vocab_start_index = partition_vocab_size * rank + vocab_end_index = vocab_start_index + partition_vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index) + masked_target = targets.clone() - vocab_start_index + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. + logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(targets) + predicted_logits[target_mask] = 0.0 + # All reduce is needed to get the chunks from other GPUs. + torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) + + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = torch.exp(vocab_parallel_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits) - predicted_logits + # Store softmax, target-mask and masked-target for backward pass. + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + return loss + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + + # All the inputs have softmax as thier gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None, None + + +@LOSSES.register_module +class VocabParallelCrossEntropyLoss1D(_Loss): + """Vocab parallel cross entropy loss for 1D parallelism. + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + """ + + def __init__(self, reduction=True): + super().__init__() + self.reduction_mean = reduction + + def forward(self, logits, targets, process_group=None): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. + """ + loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group) + if self.reduction_mean: + loss = loss.mean() + return loss diff --git a/colossalai/nn/loss/loss_2d.py b/colossalai/nn/loss/loss_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..cb12e723c3232446bf2b911730ff31f7274ea6e2 --- /dev/null +++ b/colossalai/nn/loss/loss_2d.py @@ -0,0 +1,156 @@ +import torch +import torch.distributed as dist +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d +from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization +from colossalai.registry import LOSSES +from colossalai.utils import get_current_device +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.functional import cross_entropy +from torch.nn.modules.loss import _Loss + + +@LOSSES.register_module +class CrossEntropyLoss2D(_Loss): + r"""Cross entropy loss for 2D parallelism + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + + The ``args`` and ``kwargs`` should include parameters below: + :: + + weight (Tensor, optional) + size_average (bool, optional) + ignore_index (int, optional) + reduce (bool, optional) + label_smoothing (float, optional) + + More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in + `Cross_entropy `_. + """ + + def __init__(self, reduction=True, *args, **kwargs): + super().__init__() + assert_summa_initialization() + self.reduction_mean = reduction + self.loss_args = args + self.loss_kwargs = kwargs + + def forward(self, logits, targets): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. + + Returns: + float: the loss between logits and targets. + """ + targets = split_batch_2d(targets) + loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) + if self.reduction_mean: + loss = loss.mean() + loss = reduce_by_batch_2d(loss, True) + return loss + + +class _VocabParallelCrossEntropy2D(torch.autograd.Function): + ### Modified based on megatron.mpu.cross_entropy ### + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, logits, targets): + # logits: [b/q, h/q] + # labels: [b/q] + # loss: [b/q] + # vocab_parallel_logits: [b/q, s, v/q] + # target: [b/q, s] + logits_max = torch.max(logits, dim=-1)[0] + torch.distributed.all_reduce(logits_max, + op=torch.distributed.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + # Subtract the maximum value. + # vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + logits = logits - logits_max.unsqueeze(dim=-1) + + vocab_size = logits.size(-1) + rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + vocab_start = rank * (vocab_size) + vocab_end = (rank + 1) * (vocab_size) - 1 + + target_mask = (targets < vocab_start) | (targets > vocab_end) + + masked_target = targets.clone() - vocab_start + masked_target[target_mask] = 0 + arange_1d = torch.arange( + start=0, + end=logits.size()[0], + ) + predicted_logits = logits[arange_1d, masked_target] + predicted_logits[target_mask] = 0. + dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + + exp_logits = torch.exp(logits) + sum_exp_logits = exp_logits.sum(dim=1) + dist.all_reduce(sum_exp_logits, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + + loss = torch.log(sum_exp_logits) - predicted_logits + + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target) + + return loss + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target = ctx.saved_tensors + + # All the inputs have softmax as their gradient. + grad_input = softmax + + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(output_grad.unsqueeze(dim=-1)) + + return grad_input, None + + +@LOSSES.register_module +class VocabParallelCrossEntropyLoss2D(_Loss): + """Vocab parallel cross entropy loss for 2D parallelism. + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + """ + + def __init__(self, reduction=True): + super().__init__() + self.reduction_mean = reduction + + def forward(self, logits, targets): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. + """ + targets = split_batch_2d(targets) + loss = _VocabParallelCrossEntropy2D.apply( + logits, + targets, + ) + if self.reduction_mean: + loss = loss.mean() + loss = reduce_by_batch_2d(loss, True) + return loss diff --git a/colossalai/nn/loss/loss_2p5d.py b/colossalai/nn/loss/loss_2p5d.py new file mode 100644 index 0000000000000000000000000000000000000000..f8e3324fc5ff8fe3d28ea25798e33ff9aeb26d25 --- /dev/null +++ b/colossalai/nn/loss/loss_2p5d.py @@ -0,0 +1,149 @@ +import torch +import torch.distributed as dist +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d +from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization +from colossalai.registry import LOSSES +from colossalai.utils import get_current_device +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.functional import cross_entropy +from torch.nn.modules.loss import _Loss + + +@LOSSES.register_module +class CrossEntropyLoss2p5D(_Loss): + r"""Cross entropy loss for 2.5D parallelism + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + + The ``args`` and ``kwargs`` should include parameters below: + :: + + weight (Tensor, optional) + size_average (bool, optional) + ignore_index (int, optional) + reduce (bool, optional) + label_smoothing (float, optional) + + More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in + `Cross_entropy `_. + """ + + def __init__(self, reduction=True, *args, **kwargs): + super().__init__() + assert_tesseract_initialization() + self.reduction_mean = reduction + self.loss_args = args + self.loss_kwargs = kwargs + + def forward(self, logits, targets): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. + """ + targets = split_batch_2p5d(targets) + loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) + if self.reduction_mean: + loss = loss.mean() + loss = reduce_by_batch_2p5d(loss, True) + return loss + + +class _VocabParallelCrossEntropy2p5D(torch.autograd.Function): + ### Modified based on megatron.mpu.cross_entropy ### + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, logits, targets): + # logits: [b/dq, h/q] + # loss: [b/dq] + # targets: [b/dq, h/q] + logits_max = torch.max(logits, dim=-1)[0] + torch.distributed.all_reduce(logits_max, + op=torch.distributed.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + # Subtract the maximum value. + logits = logits - logits_max.unsqueeze(dim=-1) + + vocab_size = logits.size(-1) + rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + vocab_start = rank * (vocab_size) + vocab_end = (rank + 1) * (vocab_size) - 1 + + target_mask = (targets < vocab_start) | (targets > vocab_end) + + masked_target = targets.clone() - vocab_start + masked_target[target_mask] = 0 + arange_1d = torch.arange( + start=0, + end=logits.size()[0], + ) + predicted_logits = logits[arange_1d, masked_target] + predicted_logits[target_mask] = 0. + dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + + exp_logits = torch.exp(logits) + sum_exp_logits = exp_logits.sum(dim=1) + dist.all_reduce(sum_exp_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + + loss = torch.log(sum_exp_logits) - predicted_logits + + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target) + + return loss + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target = ctx.saved_tensors + + # All the inputs have softmax as their gradient. + grad_input = softmax + + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(output_grad.unsqueeze(dim=-1)) + + return grad_input, None + + +@LOSSES.register_module +class VocabParallelCrossEntropyLoss2p5D(_Loss): + """ + Vocab parallel cross entropy loss for 2.5D parallelism + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + """ + + def __init__(self, reduction=True): + super().__init__() + self.reduction_mean = reduction + + def forward(self, logits, targets): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. + """ + targets = split_batch_2p5d(targets) + loss = _VocabParallelCrossEntropy2p5D.apply(logits, targets) + if self.reduction_mean: + loss = loss.mean() + loss = reduce_by_batch_2p5d(loss, True) + + return loss diff --git a/colossalai/nn/loss/loss_3d.py b/colossalai/nn/loss/loss_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..e76439191fdbc8fc31737e21b5c2b17e6685059b --- /dev/null +++ b/colossalai/nn/loss/loss_3d.py @@ -0,0 +1,147 @@ +import torch +import torch.distributed as dist +from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d +from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env +from colossalai.registry import LOSSES +from colossalai.utils import get_current_device +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.functional import cross_entropy +from torch.nn.modules.loss import _Loss + + +@LOSSES.register_module +class CrossEntropyLoss3D(_Loss): + r"""Cross entropy loss for 3D parallelism. + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + + The ``args`` and ``kwargs`` should include parameters below: + :: + + weight (Tensor, optional) + size_average (bool, optional) + ignore_index (int, optional) + reduce (bool, optional) + label_smoothing (float, optional) + + More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in + `Cross_entropy `_. + """ + + def __init__(self, reduction=True, *args, **kwargs): + super().__init__() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.reduction_mean = reduction + self.loss_args = args + self.loss_kwargs = kwargs + + def forward(self, logits, targets): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. + """ + targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) + targets = split_tensor_3d(targets, 0, self.input_parallel_mode) + loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) + if self.reduction_mean: + loss = loss.mean() + loss = reduce_by_batch_3d(loss, self.input_parallel_mode, self.weight_parallel_mode, True) + return loss + + +class _VocabParallelCrossEntropy3D(torch.autograd.Function): + # Adapted from megatron.mpu.cross_entropy + # loss[i] = -logits[i][targets] + log(sum(exp(logits[i]))) + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, logits, targets, output_parallel_mode): + # logits: [b/q^2, c/q] + # labels: [b/q^2] + # loss: [b/q^2] + logits_max = torch.max(logits, dim=-1)[0] + dist.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(output_parallel_mode)) + # Subtract the maximum value. + logits = logits - logits_max.unsqueeze(dim=-1) + + vocab_size_per_partition = logits.size()[-1] + rank = gpc.get_local_rank(output_parallel_mode) + vocab_start = rank * vocab_size_per_partition + vocab_end = (rank + 1) * vocab_size_per_partition - 1 + + # loss[i] = 0 if targets[i] < vocab_start or targets[i] > vocab_end + target_mask = (targets < vocab_start) | (targets > vocab_end) + masked_target = targets.clone() - vocab_start + masked_target[target_mask] = 0 + arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device()) + predicted_logits = logits[arange_1d, masked_target] + predicted_logits = predicted_logits.clone().contiguous().view_as(targets) + predicted_logits[target_mask] = 0. + dist.all_reduce(predicted_logits, group=gpc.get_group(output_parallel_mode)) + + # Loss = log(sum(exp(logits))) - predicted-logit. + exp_logits = torch.exp(logits) + sum_exp_logits = exp_logits.sum(dim=-1) + dist.all_reduce(sum_exp_logits, group=gpc.get_group(output_parallel_mode)) + loss = torch.log(sum_exp_logits) - predicted_logits + + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target) + + return loss + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target = ctx.saved_tensors + + # All the inputs have softmax as thier gradient. + input_grad = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = input_grad.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + input_grad.mul_(output_grad.unsqueeze(dim=-1)) + + return input_grad, None, None, None + + +@LOSSES.register_module +class VocabParallelCrossEntropyLoss3D(_Loss): + """Vocab parallel cross entropy loss for 2D parallelism. + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + """ + + def __init__(self, reduction=True): + super().__init__() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.reduction_mean = reduction + + def forward(self, logits, targets): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. + """ + targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) + targets = split_tensor_3d(targets, 0, self.input_parallel_mode) + loss = _VocabParallelCrossEntropy3D.apply(logits, targets, self.output_parallel_mode) + if self.reduction_mean: + loss = loss.mean() + loss = reduce_by_batch_3d(loss, self.input_parallel_mode, self.weight_parallel_mode, True) + return loss diff --git a/colossalai/nn/loss/loss_moe.py b/colossalai/nn/loss/loss_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..9c576c6a0c61bc2bc0ebf042a20559f9d3491456 --- /dev/null +++ b/colossalai/nn/loss/loss_moe.py @@ -0,0 +1,80 @@ +import torch.nn as nn +from colossalai.registry import LOSSES +from torch.nn.modules.loss import _Loss +from colossalai.context.moe_context import MOE_CONTEXT + + +@LOSSES.register_module +class MoeCrossEntropyLoss(_Loss): + r"""torch.nn.CrossEntropyLoss added with auxiliary loss. + + Args: + input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + target (:class:`torch.tensor`): Ground truth class indices or class probabilities. + aux_weight (float, optional): Weight of auxiliary loss in total loss.Defaults 0.01. + + The ``args`` and ``kwargs`` should include parameters below: + :: + + weight (Tensor, optional) + size_average (bool, optional) + ignore_index (int, optional) + reduce (bool, optional) + reduction (str, optional) + label_smoothing (float, optional) + + More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in + `Cross_entropy `_. + """ + + def __init__(self, aux_weight: float = 0.01, *args, **kwargs): + super().__init__() + self.loss = nn.CrossEntropyLoss(*args, **kwargs) + self.aux_weight = aux_weight + + def forward(self, *args): + """ + The ``args`` should at least include parameters below: + :: + + input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + target (:class:`torch.tensor`): Ground truth class indices or class probabilities. + + More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in + `Cross_entropy `_. + """ + main_loss = self.loss(*args) + aux_loss = MOE_CONTEXT.get_loss() + return main_loss + self.aux_weight * aux_loss + + +@LOSSES.register_module +class MoeLoss(_Loss): + """A wrapper class for any loss module to add with auxiliary loss. + + Args: + aux_weight (float): Weight of auxiliary loss in total loss. + loss_fn (``Callable``): Loss function. + args (list): Args in loss function. + kwargs (dict): Kwargs in loss function + """ + + def __init__(self, aux_weight: float, loss_fn, *args, **kwargs): + super().__init__() + self.loss_fn = loss_fn(*args, **kwargs) + self.aux_weight = aux_weight + + def forward(self, *args, **kwargs): + """ + The ``args`` and ``kwargs`` should at least include parameters below: + :: + + input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + target (:class:`torch.tensor`): Ground truth class indices or class probabilities. + + Note: + The ``args`` and ``kwargs`` may include different parameters varying with different loss function. + """ + main_loss = self.loss_fn(*args, **kwargs) + aux_loss = MOE_CONTEXT.get_loss() + return main_loss + self.aux_weight * aux_loss diff --git a/colossalai/nn/lr_scheduler/__init__.py b/colossalai/nn/lr_scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34731ee901a0d37a3296e57915df08c38f0b648b --- /dev/null +++ b/colossalai/nn/lr_scheduler/__init__.py @@ -0,0 +1,12 @@ +from .cosine import CosineAnnealingLR, CosineAnnealingWarmupLR, FlatAnnealingLR, FlatAnnealingWarmupLR +from .linear import LinearWarmupLR +from .multistep import MultiStepLR, MultiStepWarmupLR +from .onecycle import OneCycleLR +from .poly import PolynomialLR, PolynomialWarmupLR +from .torch import LambdaLR, MultiplicativeLR, StepLR, ExponentialLR + +__all__ = [ + 'CosineAnnealingLR', 'CosineAnnealingWarmupLR', 'FlatAnnealingLR', 'FlatAnnealingWarmupLR', 'LinearWarmupLR', + 'MultiStepLR', 'MultiStepWarmupLR', 'OneCycleLR', 'PolynomialLR', 'PolynomialWarmupLR', 'LambdaLR', + 'MultiplicativeLR', 'StepLR', 'ExponentialLR' +] diff --git a/colossalai/nn/lr_scheduler/cosine.py b/colossalai/nn/lr_scheduler/cosine.py new file mode 100644 index 0000000000000000000000000000000000000000..aab523bef8b30dafc65f60a9475a8b9a70326738 --- /dev/null +++ b/colossalai/nn/lr_scheduler/cosine.py @@ -0,0 +1,121 @@ +from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR + +from colossalai.registry import LR_SCHEDULERS +from .delayed import DelayerScheduler, WarmupDelayerScheduler, WarmupScheduler + + +@LR_SCHEDULERS.register_module +class CosineAnnealingLR(_CosineAnnealingLR): + r"""Set the learning rate of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial lr and + :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + + .. math:: + \begin{aligned} + \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), + & T_{cur} \neq (2k+1)T_{max}; \\ + \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) + \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), + & T_{cur} = (2k+1)T_{max}. + \end{aligned} + + When last_epoch=-1, sets initial lr as lr. Notice that because the schedule + is defined recursively, the learning rate can be simultaneously modified + outside this scheduler by other operators. If the learning rate is set + solely by this scheduler, the learning rate at each step becomes: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only + implements the cosine annealing part of SGDR, and not the restarts. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + eta_min (int, optional): Minimum learning rate, defaults to 0. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, total_steps: int, eta_min: int = 0, last_epoch: int = -1, **kwargs): + super().__init__(optimizer, total_steps, eta_min=eta_min, last_epoch=last_epoch) + + +@LR_SCHEDULERS.register_module +class CosineAnnealingWarmupLR(WarmupScheduler): + """Cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be applied. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + warmup_steps (int, optional): Number of warmup steps, defaults to 0. + eta_min (int, optional): Minimum learning rate, defaults to 0. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0., last_epoch: int = -1): + base_scheduler = _CosineAnnealingLR(optimizer, + total_steps - warmup_steps, + eta_min=eta_min, + last_epoch=last_epoch) + super().__init__(optimizer, warmup_steps, base_scheduler) + + +@LR_SCHEDULERS.register_module +class FlatAnnealingLR(DelayerScheduler): + """Flat and cosine annealing learning rate scheduler. The learning rate will be a fixed value before starting decay. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + pct_start (float, optional): Percent of steps before starting learning rate decay, defaults to -0.72. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, total_steps: int, pct_start: float = 0.72, last_epoch: int = -1, **kwargs): + if not (0.0 <= pct_start <= 1.0): + raise ValueError(f'pct_start must >= 0.0 and <= 1.0, got {pct_start}') + flat_steps = int(total_steps * pct_start) + anneal_steps = total_steps - flat_steps + base_scheduler = _CosineAnnealingLR(optimizer, anneal_steps) + super().__init__(optimizer, flat_steps, base_scheduler, last_epoch=last_epoch) + + +@LR_SCHEDULERS.register_module +class FlatAnnealingWarmupLR(WarmupDelayerScheduler): + """Flat and cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be + applied, and then the learning rate will be a fixed value before starting decay. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + warmup_steps (int, optional): Number of warmup steps, defaults to 0. + pct_start (float, optional): Percent of steps before starting learning rate decay, defaults to -0.72. + eta_min (int, optional): Minimum learning rate, defaults to 0. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + pct_start: float = 0.72, + eta_min: int = 0, + last_epoch: int = -1, + **kwargs): + if not (0.0 <= pct_start <= 1.0): + raise ValueError(f'pct_start must >= 0.0 and <= 1.0, got {pct_start}') + flat_steps = int((total_steps - warmup_steps) * pct_start) + anneal_steps = total_steps - warmup_steps - flat_steps + base_scheduler = _CosineAnnealingLR(optimizer, anneal_steps, eta_min=eta_min) + super().__init__(optimizer, warmup_steps, flat_steps, base_scheduler, last_epoch=last_epoch) diff --git a/colossalai/nn/lr_scheduler/delayed.py b/colossalai/nn/lr_scheduler/delayed.py new file mode 100644 index 0000000000000000000000000000000000000000..a73ff8ae37ace4f82433c55048dac9cf729389fe --- /dev/null +++ b/colossalai/nn/lr_scheduler/delayed.py @@ -0,0 +1,176 @@ +from torch.optim.lr_scheduler import _LRScheduler + + +class _enable_get_lr_call: + + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + + +class DelayerScheduler(_LRScheduler): + """Starts with a flat lr schedule until it reaches N epochs then applies + the specific scheduler (For example: ReduceLROnPlateau) + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + delay_epochs (int): Number of epochs to keep the initial lr until starting applying the scheduler. + after_scheduler (:class:`torch.optim.lr_scheduler`): After target_epoch, use this scheduler. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, delay_epochs, after_scheduler, last_epoch=-1): + if delay_epochs < 0: + raise ValueError(f'delay_epochs must >= 0, got {delay_epochs}') + self.delay_epochs = delay_epochs + self.after_scheduler = after_scheduler + self.finished = False + super().__init__(optimizer, last_epoch) + + def state_dict(self): + state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} + if isinstance(state_dict['after_scheduler'], _LRScheduler): + state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__ + state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() + del state_dict['after_scheduler'] + else: + raise NotImplementedError() + return state_dict + + def get_lr(self): + if self.last_epoch >= self.delay_epochs: + if not self.finished: + self.after_scheduler.base_lrs = self.base_lrs + self.finished = True + with _enable_get_lr_call(self.after_scheduler): + return self.after_scheduler.get_lr() + + return self.base_lrs + + def step(self, epoch=None): + if self.finished: + if epoch is None: + self.after_scheduler.step(None) + self._last_lr = self.after_scheduler.get_last_lr() + else: + self.after_scheduler.step(epoch - self.delay_epochs) + self._last_lr = self.after_scheduler.get_last_lr() + else: + return super(DelayerScheduler, self).step(epoch) + + +class WarmupScheduler(_LRScheduler): + """Starts with a linear warmup lr schedule until it reaches N epochs then applies + the specific scheduler (For example: ReduceLROnPlateau). + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + warmup_epochs (int): Number of epochs to linearly warmup lr until starting applying the scheduler. + after_scheduler (:class:`torch.optim.lr_scheduler`): After target_epoch, use this scheduler. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1): + self.warmup_epochs = int(warmup_epochs) + self.after_scheduler = after_scheduler + self.finished = False + super().__init__(optimizer, last_epoch) + + def state_dict(self): + state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} + if isinstance(state_dict['after_scheduler'], _LRScheduler): + state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__ + state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() + del state_dict['after_scheduler'] + else: + raise NotImplementedError() + return state_dict + + def get_lr(self): + if self.last_epoch >= self.warmup_epochs: + if not self.finished: + self.after_scheduler.base_lrs = self.base_lrs + self.finished = True + return self.after_scheduler.get_lr() + + return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs] + + def step(self, epoch=None): + if self.finished: + if epoch is None: + self.after_scheduler.step(None) + self._last_lr = self.after_scheduler.get_last_lr() + else: + self.after_scheduler.step(epoch - self.warmup_epochs) + self._last_lr = self.after_scheduler.get_last_lr() + else: + return super().step(epoch) + + +class WarmupDelayerScheduler(_LRScheduler): + """Starts with a linear warmup lr schedule until it reaches N epochs and a flat lr schedule + until it reaches M epochs then applies the specific scheduler (For example: ReduceLROnPlateau). + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + warmup_epochs (int): Number of epochs to linearly warmup lr until starting applying the scheduler. + delay_epochs (int): Number of epochs to keep the initial lr until starting applying the scheduler. + after_scheduler (:class:`torch.optim.lr_scheduler`): After target_epoch, use this scheduler. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, warmup_epochs, delay_epochs, after_scheduler, last_epoch=-1): + if delay_epochs < 0: + raise ValueError(f'delay_epochs must >= 0, got {delay_epochs}') + if warmup_epochs < 0: + raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}') + self.warmup_epochs = warmup_epochs + self.delay_epochs = delay_epochs + self.after_scheduler = after_scheduler + self.finished = False + super().__init__(optimizer, last_epoch) + + def state_dict(self): + state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} + if isinstance(state_dict['after_scheduler'], _LRScheduler): + state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__ + state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() + del state_dict['after_scheduler'] + else: + raise NotImplementedError() + return state_dict + + def get_lr(self): + if self.last_epoch >= self.warmup_epochs + self.delay_epochs: + if not self.finished: + self.after_scheduler.base_lrs = self.base_lrs + # reset lr to base_lr + for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs): + group['lr'] = base_lr + self.finished = True + with _enable_get_lr_call(self.after_scheduler): + return self.after_scheduler.get_lr() + elif self.last_epoch >= self.warmup_epochs: + return self.base_lrs + + return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs] + + def step(self, epoch=None): + if self.finished: + if epoch is None: + self.after_scheduler.step(None) + self._last_lr = self.after_scheduler.get_last_lr() + else: + self.after_scheduler.step(epoch - self.warmup_epochs) + self._last_lr = self.after_scheduler.get_last_lr() + else: + return super().step(epoch) diff --git a/colossalai/nn/lr_scheduler/linear.py b/colossalai/nn/lr_scheduler/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..556938b8a60c8ae5eaf116e17710f5253f4bce28 --- /dev/null +++ b/colossalai/nn/lr_scheduler/linear.py @@ -0,0 +1,28 @@ +from torch.optim.lr_scheduler import _LRScheduler + +from colossalai.registry import LR_SCHEDULERS + + +@LR_SCHEDULERS.register_module +class LinearWarmupLR(_LRScheduler): + """Linearly warmup learning rate and then linearly decay. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + warmup_steps (int, optional): Number of warmup steps, defaults to 0 + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, last_epoch: int = -1, **kwargs): + self.warmup_steps = warmup_steps + self.total_steps = total_steps + super().__init__(optimizer, last_epoch=last_epoch) + + def get_lr(self): + if self.last_epoch < self.warmup_steps: + return [(self.last_epoch + 1) / (self.warmup_steps + 1) * lr for lr in self.base_lrs] + else: + return [(self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr + for lr in self.base_lrs] diff --git a/colossalai/nn/lr_scheduler/multistep.py b/colossalai/nn/lr_scheduler/multistep.py new file mode 100644 index 0000000000000000000000000000000000000000..29531a9e385524913b9b6bfb4c89c7ce40b05792 --- /dev/null +++ b/colossalai/nn/lr_scheduler/multistep.py @@ -0,0 +1,62 @@ +from typing import List + +from torch.optim.lr_scheduler import MultiStepLR as _MultiStepLR + +from colossalai.registry import LR_SCHEDULERS +from .delayed import WarmupScheduler + + +@LR_SCHEDULERS.register_module +class MultiStepLR(_MultiStepLR): + """Decays the learning rate of each parameter group by gamma once the + number of epoch reaches one of the milestones. Notice that such decay can + happen simultaneously with other changes to the learning rate from outside + this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + milestones (List[int], optional): List of epoch indices. Must be increasing, defaults to None. + gamma (float, optional): Multiplicative factor of learning rate decay, defaults to 0.1. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, + optimizer, + total_steps: int, + milestones: List[int] = None, + gamma: float = 0.1, + last_epoch: int = -1, + **kwargs): + super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch) + + +@LR_SCHEDULERS.register_module +class MultiStepWarmupLR(WarmupScheduler): + """Multistep learning rate scheduler with warmup. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + warmup_steps (int, optional): Number of warmup steps, defaults to 0. + milestones (List[int], optional): List of epoch indices. Must be increasing, defaults to None. + gamma (float, optional): Multiplicative factor of learning rate decay, defaults to 0.1. + num_steps_per_epoch (int, optional): Number of steps per epoch, defaults to -1. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + milestones: List[int] = None, + gamma: float = 0.1, + last_epoch: int = -1, + **kwargs): + if len(milestones) == 0: + raise ValueError('milestones cannot be empty') + milestones = [v - warmup_steps for v in milestones if v >= warmup_steps] + base_scheduler = _MultiStepLR(optimizer, milestones=milestones, gamma=gamma) + super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) diff --git a/colossalai/nn/lr_scheduler/onecycle.py b/colossalai/nn/lr_scheduler/onecycle.py new file mode 100644 index 0000000000000000000000000000000000000000..8007fd36008ea01830500a1f3272d0407d2ca3f2 --- /dev/null +++ b/colossalai/nn/lr_scheduler/onecycle.py @@ -0,0 +1,94 @@ +from torch.optim.lr_scheduler import OneCycleLR as _OneCycleLR + +from colossalai.registry import LR_SCHEDULERS + + +@LR_SCHEDULERS.register_module +class OneCycleLR(_OneCycleLR): + r"""Sets the learning rate of each parameter group according to the + 1cycle learning rate policy. The 1cycle policy anneals the learning + rate from an initial learning rate to some maximum learning rate and then + from that maximum learning rate to some minimum learning rate much lower + than the initial learning rate. + This policy was initially described in the paper `Super-Convergence: + Very Fast Training of Neural Networks Using Large Learning Rates`_. + The 1cycle learning rate policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + This scheduler is not chainable. + Note also that the total number of steps in the cycle can be determined in one + of two ways (listed in order of precedence): + + * A value for total_steps is explicitly provided. + * A number of epochs (epochs) and a number of steps per epoch (steps_per_epoch) are provided. + In this case, the number of total steps is inferred by total_steps = epochs * steps_per_epoch + + You must either provide a value for total_steps or provide a value for both + epochs and steps_per_epoch. + The default behaviour of this scheduler follows the fastai implementation of 1cycle, which + claims that "unpublished work has shown even better results by using only two phases". To + mimic the behaviour of the original paper instead, set ``three_phase=True``. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + pct_start (float, optional): + The percentage of the cycle (in number of steps) spent increasing the learning rate, defaults to 0.3. + anneal_strategy (str, optional): {'cos', 'linear'}, Specifies the annealing strategy: + "cos" for cosine annealing, "linear" for linear annealing, defaults to 'cos'. + cycle_momentum (bool, optional): If ``True``, momentum is cycled inversely + to learning rate between 'base_momentum' and 'max_momentum', defaults to True. + base_momentum (float, optional): Lower momentum boundaries in the cycle for each parameter group. + Note that momentum is cycled inversely to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr', defaults to 0.85. + max_momentum (float, optional): Upper momentum boundaries in the cycle for each parameter group. + Functionally, it defines the cycle amplitude (max_momentum - base_momentum). + Note that momentum is cycled inversely to learning rate; at the start of a cycle, momentum is 'max_momentum' + and learning rate is 'base_lr', defaults to 0.95. + div_factor (float, optional): Determines the initial learning rate via + initial_lr = max_lr/div_factor, defaults to 25.0. + final_div_factor (float, optional): Determines the minimum learning rate via + min_lr = initial_lr/final_div_factor, defaults to 10000.0. + last_epoch (int, optional): The index of the last batch. This parameter is used when resuming a training job. + Since `step()` should be invoked after each batch instead of after each epoch, this number represents + the total number of *batches* computed, not the total number of epochs computed. + When last_epoch=-1, the schedule is started from the beginning, defaults to -1 + + The ``kwargs`` for initializing torch.optim.lr_scheduler.OneCycleLR should include parameters below: + :: + + epochs (int, optional, default=None) + steps_per_epoch (int, optional, default=None) + three_phase (bool, optional, default=False) + verbose (bool, optional, default=False) + + More details about kwargs could be found in + `OneCycleLR `_. + + .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: + https://arxiv.org/abs/1708.07120 + """ + + def __init__(self, + optimizer, + total_steps: int, + pct_start=0.3, + anneal_strategy='cos', + cycle_momentum=True, + base_momentum=0.85, + max_momentum=0.95, + div_factor=25.0, + final_div_factor=10000.0, + last_epoch=-1, + **kwargs): + max_lrs = list(map(lambda group: group['lr'], optimizer.param_groups)) + super().__init__(optimizer, + max_lrs, + total_steps=total_steps, + pct_start=pct_start, + anneal_strategy=anneal_strategy, + cycle_momentum=cycle_momentum, + base_momentum=base_momentum, + max_momentum=max_momentum, + div_factor=div_factor, + final_div_factor=final_div_factor, + last_epoch=last_epoch) diff --git a/colossalai/nn/lr_scheduler/poly.py b/colossalai/nn/lr_scheduler/poly.py new file mode 100644 index 0000000000000000000000000000000000000000..16352bc5175ff022f3111bec8503443cdae772b9 --- /dev/null +++ b/colossalai/nn/lr_scheduler/poly.py @@ -0,0 +1,66 @@ +from torch.optim.lr_scheduler import _LRScheduler + +from colossalai.registry import LR_SCHEDULERS +from .delayed import WarmupScheduler + + +@LR_SCHEDULERS.register_module +class PolynomialLR(_LRScheduler): + """Polynomial learning rate scheduler. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + end_lr (float, optional): Minimum learning rate, defaults to 0.0001. + power (float, optional): The power of polynomial, defaults to 1.0. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, + optimizer, + total_steps: int, + end_lr: float = 0.0001, + power: float = 1.0, + last_epoch: int = -1, + **kwargs): + if end_lr < 0: + raise ValueError(f'end_lr must >= 0, got {end_lr}') + self.total_steps = total_steps + self.end_lr = end_lr + self.power = power + super().__init__(optimizer, last_epoch=last_epoch) + + def get_lr(self): + return self._get_closed_form_lr() + + def _get_closed_form_lr(self): + return [(base_lr - self.end_lr) * + ((1 - min(self.last_epoch, self.total_steps) / self.total_steps)**self.power) + self.end_lr + for base_lr in self.base_lrs] + + +@LR_SCHEDULERS.register_module +class PolynomialWarmupLR(WarmupScheduler): + """Polynomial learning rate scheduler with warmup. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + warmup_steps (int, optional): Number of warmup steps, defaults to 0. + end_lr (float, optional): Minimum learning rate, defaults to 0.0001. + power (float, optional): The power of polynomial, defaults to 1.0. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + end_lr: float = 0.0001, + power: float = 1.0, + last_epoch: int = -1, + **kwargs): + base_scheduler = PolynomialLR(optimizer, total_steps - warmup_steps, end_lr=end_lr, power=power) + super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) diff --git a/colossalai/nn/lr_scheduler/torch.py b/colossalai/nn/lr_scheduler/torch.py new file mode 100644 index 0000000000000000000000000000000000000000..05d2a49c1ea5b0363f82dc61bc3f737c5c2a845a --- /dev/null +++ b/colossalai/nn/lr_scheduler/torch.py @@ -0,0 +1,77 @@ +from torch.optim.lr_scheduler import LambdaLR as _LambdaLR +from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR +from torch.optim.lr_scheduler import StepLR as _StepLR +from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR + +from colossalai.registry import LR_SCHEDULERS + + +@LR_SCHEDULERS.register_module +class LambdaLR(_LambdaLR): + """Sets the learning rate of each parameter group to the initial lr + times a given function. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + lr_lambda (Union[``function``, ``list[function]``]): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such functions, + one for each group in optimizer.param_groups, defaults to None. + last_epoch (int, optional): The index of last epoch, defaults to -1. + """ + + def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None: + super().__init__(optimizer, lr_lambda, last_epoch=last_epoch) + + +@LR_SCHEDULERS.register_module +class MultiplicativeLR(_MultiplicativeLR): + """Multiply the learning rate of each parameter group by the factor given + in the specified function. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + lr_lambda (Union[``function``, ``list[function]``]): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such functions, + one for each group in optimizer.param_groups, defaults to None. + last_epoch (int, optional): The index of last epoch, defaults to -1. + """ + + def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None: + super().__init__(optimizer, lr_lambda, last_epoch=last_epoch) + + +@LR_SCHEDULERS.register_module +class StepLR(_StepLR): + """Decays the learning rate of each parameter group by gamma every + step_size epochs. Notice that such decay can happen simultaneously with + other changes to the learning rate from outside this scheduler. When + last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + step_size (int, optional): Period of learning rate decay, defaults to 1. + gamma (float, optional): Multiplicative factor of learning rate decay, defaults to 0.1. + last_epoch (int, optional): The index of last epoch, defaults to -1. + """ + + def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, last_epoch: int = -1) -> None: + super().__init__(optimizer, step_size, gamma=gamma, last_epoch=last_epoch) + + +@LR_SCHEDULERS.register_module +class ExponentialLR(_ExponentialLR): + """Decays the learning rate of each parameter group by gamma every epoch. + When last_epoch=-1, sets initial lr as lr + + Args: + optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Wrapped optimizer. + total_steps (int): Number of total training steps. + gamma (float, optional): Multiplicative factor of learning rate decay, defaults to 1.0. + last_epoch (int, optional): The index of last epoch, defaults to -1. + """ + + def __init__(self, optimizer, total_steps, gamma: float = 1.0, last_epoch: int = -1) -> None: + super().__init__(optimizer, gamma, last_epoch=last_epoch) diff --git a/colossalai/nn/metric/__init__.py b/colossalai/nn/metric/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1db3a5eb2284d55d7fe816fb7f8ef6b94b89195 --- /dev/null +++ b/colossalai/nn/metric/__init__.py @@ -0,0 +1,26 @@ +from torch import nn + +from ._utils import calc_acc +from .accuracy_2d import Accuracy2D +from .accuracy_2p5d import Accuracy2p5D +from .accuracy_3d import Accuracy3D +from colossalai.nn.layer.utils import get_tensor_parallel_mode + +_parallel_accuracy = { + '2d': Accuracy2D, + '2.5d': Accuracy2p5D, + '3d': Accuracy3D, +} + + +class Accuracy(nn.Module): + def __init__(self): + super().__init__() + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel not in _parallel_accuracy: + self.acc = calc_acc + else: + self.acc = _parallel_accuracy[tensor_parallel]() + + def forward(self, *args): + return self.acc(*args) diff --git a/colossalai/nn/metric/_utils.py b/colossalai/nn/metric/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8706ffc101b0e3f5c007d3b08e4ebe0f1aef6e72 --- /dev/null +++ b/colossalai/nn/metric/_utils.py @@ -0,0 +1,7 @@ +import torch + + +def calc_acc(logits, targets): + preds = torch.argmax(logits, dim=-1) + correct = torch.sum(targets == preds) + return correct diff --git a/colossalai/nn/metric/accuracy_2d.py b/colossalai/nn/metric/accuracy_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..a86832973cfda4ffe89f414d9ae7342191293aac --- /dev/null +++ b/colossalai/nn/metric/accuracy_2d.py @@ -0,0 +1,29 @@ +import torch +from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d +from torch import nn + +from ._utils import calc_acc + + +class Accuracy2D(nn.Module): + """Accuracy for 2D parallelism + """ + + def __init__(self): + super().__init__() + + def forward(self, logits, targets): + """Calculate the accuracy of predicted labels. + + Args: + logits (:class:`torch.tensor`): Predicted labels. + targets (:class:`torch.tensor`): True labels from data. + + Returns: + float: the accuracy of prediction. + """ + with torch.no_grad(): + targets = split_batch_2d(targets) + correct = calc_acc(logits, targets) + correct = reduce_by_batch_2d(correct) + return correct diff --git a/colossalai/nn/metric/accuracy_2p5d.py b/colossalai/nn/metric/accuracy_2p5d.py new file mode 100644 index 0000000000000000000000000000000000000000..3044da065de136b5e2d5f73b2690893f3dc8e240 --- /dev/null +++ b/colossalai/nn/metric/accuracy_2p5d.py @@ -0,0 +1,29 @@ +import torch +from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d +from torch import nn + +from ._utils import calc_acc + + +class Accuracy2p5D(nn.Module): + """Accuracy for 2p5D parallelism + """ + + def __init__(self): + super().__init__() + + def forward(self, logits, targets): + """Calculate the accuracy of predicted labels. + + Args: + logits (:class:`torch.tensor`): Predicted labels. + targets (:class:`torch.tensor`): True labels from data. + + Returns: + float: the accuracy of prediction. + """ + with torch.no_grad(): + targets = split_batch_2p5d(targets) + correct = calc_acc(logits, targets) + correct = reduce_by_batch_2p5d(correct) + return correct diff --git a/colossalai/nn/metric/accuracy_3d.py b/colossalai/nn/metric/accuracy_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..1a917a6df00e9160771ef90ce90ffc2cabcf1407 --- /dev/null +++ b/colossalai/nn/metric/accuracy_3d.py @@ -0,0 +1,33 @@ +import torch +from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d +from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env +from torch import nn + +from ._utils import calc_acc + + +class Accuracy3D(nn.Module): + """Accuracy for 3D parallelism + """ + def __init__(self): + super().__init__() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + + def forward(self, logits, targets): + """Calculate the accuracy of predicted labels. + + Args: + logits (:class:`torch.tensor`): Predicted labels. + targets (:class:`torch.tensor`): True labels from data. + + Returns: + float: the accuracy of prediction. + """ + with torch.no_grad(): + targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) + targets = split_tensor_3d(targets, 0, self.input_parallel_mode) + correct = calc_acc(logits, targets) + correct = reduce_by_batch_3d(correct, self.input_parallel_mode, self.weight_parallel_mode) + return correct diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..268e37d57997029f7e06419bc1cbc4ae07d40967 --- /dev/null +++ b/colossalai/nn/optimizer/README.md @@ -0,0 +1,82 @@ +# Colossal-AI Optimization Techniques + +## Introduction + +Welcome to the large-scale deep learning optimization techniques of [Colossal-AI](https://github.com/hpcaitech/ColossalAI), +which has been accepted as official tutorials by top conference [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), etc. + + +[Colossal-AI](https://github.com/hpcaitech/ColossalAI), a unified deep learning system for the big model era, integrates +many advanced technologies such as multi-dimensional tensor parallelism, sequence parallelism, heterogeneous memory management, +large-scale optimization, adaptive task scheduling, etc. By using Colossal-AI, we could help users to efficiently and +quickly deploy large AI model training and inference, reducing large AI model training budgets and scaling down the labor cost of learning and deployment. + +### ๐Ÿš€ Quick Links + +[**Colossal-AI**](https://github.com/hpcaitech/ColossalAI) | +[**Paper**](https://arxiv.org/abs/2110.14883) | +[**Documentation**](https://www.colossalai.org/) | +[**Forum**](https://github.com/hpcaitech/ColossalAI/discussions) | +[**Slack**](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) + + +## Table of Content + +Large transformer models display promising performance on a wide spectrum of AI applications. +Both academia and industry are scaling DL training on larger clusters. However, degrading generalization performance, non-negligible communication overhead, and increasing model size prevent DL researchers and engineers from exploring large-scale AI models. + +We aim to provide a clear sketch of the optimizations for large-scale deep learning with regard to model accuracy and model efficiency. +One way to achieve the goal of maintaining or improving the model accuracy in the large-scale setting while maintaining compute efficiency is to design algorithms that +are less communication and memory hungry. Notably, they are not mutually exclusive but can +be optimized jointly to further speed up training. + +1. Model Accuracy + - Gradient Descent Optimization + - Gradient Descent Variants + - Momentum + - Adaptive Gradient + - Large Batch Training Optimization + - LARS + - LAMB + - Generalization Gap + - Second-Order Optimization + - Hessian-Free + - K-FAC + - Shampoo + +2. Model Accuracy + - Communication Efficiency + - Reduce Volumn of Comm. + - Reduce Frequency of Comm. + - Memory Efficiency + - Mix-Precision Training + - Memory-Efficient Methods, e.g. ZeRO, Gemini, etc. + +Some of the above are still under development. **If you wish to make a contribution to this repository, please read the `Contributing` section below.** + +## Discussion + +Discussion about the Colossal-AI project is always welcomed! We would love to exchange ideas with the community to better help this project grow. +If you think there is a need to discuss anything, you may jump to our [Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w). + +If you encounter any problem while running these optimizers, you may want to raise an issue in this repository. + +## Contributing + +This project welcomes constructive ideas and implementations from the community. + +### Update an Optimizer + +If you find that an optimizer is broken (not working) or not user-friendly, you may put up a pull request to this repository and update this optimizer. + +### Add a New Optimizer + +If you wish to add an optimizer for a specific application, please follow the steps below. + +1. create the new optimizer file in the current folder +2. Prepare the corresponding example files in the [Examples](https://github.com/hpcaitech/ColossalAI-Examples) repository to prove effectiveness of the new optimizer +3. Prepare a detailed readme on environment setup, dataset preparation, code execution, etc. in your example folder +4. Update the table of content (last section above) in this readme file + + +If your PR is accepted, we may invite you to put up a tutorial or blog in [ColossalAI Documentation](https://colossalai.org/). diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..06072648beba974c2e6f638bdc5f3d1162141056 --- /dev/null +++ b/colossalai/nn/optimizer/__init__.py @@ -0,0 +1,10 @@ +from .colossalai_optimizer import ColossalaiOptimizer +from .fused_adam import FusedAdam +from .fused_lamb import FusedLAMB +from .fused_sgd import FusedSGD +from .lamb import Lamb +from .lars import Lars +from .cpu_adam import CPUAdam +from .hybrid_adam import HybridAdam + +__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam'] diff --git a/colossalai/nn/optimizer/colossalai_optimizer.py b/colossalai/nn/optimizer/colossalai_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..34f5a9541975aa3029032d0346947c3524cd5a69 --- /dev/null +++ b/colossalai/nn/optimizer/colossalai_optimizer.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn +from torch import Tensor +from torch.optim import Optimizer +from colossalai.utils import clip_grad_norm_fp32 + + +class ColossalaiOptimizer(Optimizer): + + def __init__(self, optim: Optimizer): + self.optim = optim + + @property + def param_groups(self): + return self.optim.param_groups + + @property + def defaults(self): + return self.optim.defaults + + def add_param_group(self, *args, **kwargs): + return self.optim.add_param_group(*args, **kwargs) + + def step(self, *args, **kwargs): + return self.optim.step(*args, **kwargs) + + def zero_grad(self, *args, **kwargs): + self.optim.zero_grad(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + self.optim.load_state_dict(*args, **kwargs) + + def state_dict(self): + return self.optim.state_dict() + + def backward(self, loss: Tensor): + loss.backward() + + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + torch.autograd.backward(tensors=tensor, grad_tensors=grad) + + def clip_grad_norm(self, model: nn.Module, max_norm: float): + if max_norm > 0.0: + clip_grad_norm_fp32(model.parameters(), max_norm) diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..5b05fecc89f200e01d8f032bc4b386e2f8c49a32 --- /dev/null +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -0,0 +1,172 @@ +import math +from typing import Optional + +import torch + +from colossalai.registry import OPTIMIZERS + +from .nvme_optimizer import NVMeOptimizer + + +@OPTIMIZERS.register_module +class CPUAdam(NVMeOptimizer): + """Implements Adam algorithm. + + Supports parameters updating on both GPU and CPU, depanding on the device of paramters. + But the parameters and gradients should on the same device: + * Parameters on CPU and gradients on CPU is allowed. + * Parameters on GPU and gradients on GPU is allowed. + * Parameters on GPU and gradients on CPU is **not** allowed. + + Requires ColossalAI to be installed via ``pip install .``. + + This version of CPU Adam accelates parameters updating on CPU with SIMD. + Support of AVX2 or AVX512 is required. + + The GPU part is implemented in an naive way. + + CPU Adam also supports the hybrid precision calculation, eg. fp32 parameters and fp16 gradients. + + :class:`colossalai.nn.optimizer.CPUAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, + or ``torch.optim.Adam`` with ``adamw_mode=False`` + + Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. + + Arguments: + model_params (iterable): iterable of parameters of dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED yet in CPUAdam! + adamw_mode (boolean, optional): Apply L2 regularization or weight decay + True for decoupled weight decay(also known as AdamW) (default: True) + simd_log (boolean, optional): whether to show if you are using SIMD to + accelerate. (default: False) + nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0. + nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files. + If it's ``None``, a random temporary directory will be used. Defaults to None. + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + # Number of fp32 shards for per parameter + # Param weight, grad, momentum and variance + num_fp32_shards_per_param = 4 + + def __init__(self, + model_params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + adamw_mode=True, + nvme_offload_fraction: float = 0.0, + nvme_offload_dir: Optional[str] = None): + + default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) + super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) + self.adamw_mode = adamw_mode + try: + import colossalai._C.cpu_optim + except ImportError: + raise ImportError('Please install colossalai from source code to use CPUAdam') + self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, + adamw_mode) + + def torch_adam_update(self, + data, + grad, + exp_avg, + exp_avg_sq, + lr, + beta1, + beta2, + eps, + weight_decay, + bias_correction1, + bias_correction2, + use_adamw=False): + # FIXME(ver217): remove the below line when replace torch adam with fused adam + grad = grad.float() + + if weight_decay != 0: + if use_adamw: + data.mul_(1 - lr * weight_decay) + else: + grad = grad.add(data, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # TODO(jiaruifang) dose not support amsgrad + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + step_size = lr / bias_correction1 + + data.addcdiv_(exp_avg, denom, value=-step_size) + + @torch.no_grad() + def step(self, closure=None, div_scale: float = -1): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + self._pre_step('exp_avg', 'exp_avg_sq') + for _, group in enumerate(self.param_groups): + for _, p in enumerate(group['params']): + + if p.grad is None: + continue + + state = self.state[p] + + target_device = p.device + if len(state) == 0: + state['step'] = 0 + + # gradient momentums + state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + # gradient variances + state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + self._post_state_init(p) + + state['step'] += 1 + beta1, beta2 = group['betas'] + + if target_device.type == 'cpu': + assert p.data.numel() == p.grad.data.numel(), "parameter and gradient should have the same size" + assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" + assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" + self._pre_update(p, 'exp_avg', 'exp_avg_sq') + self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], + group['bias_correction'], p.data, p.grad.data, state['exp_avg'], + state['exp_avg_sq'], div_scale) + self._post_update(p, 'exp_avg', 'exp_avg_sq') + elif target_device.type == 'cuda': + assert div_scale == -1, "div_scale should remain default" + assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda" + assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda" + + bias_correction1 = 1 - beta1**state['step'] + bias_correction2 = 1 - beta2**state['step'] + + # adam on cuda + self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'], + beta1, beta2, group['eps'], group['weight_decay'], bias_correction1, + bias_correction2, self.adamw_mode) + else: + raise RuntimeError + self._post_step() + return loss diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..064e55a401bb281ce7072fc66aebdbe31e1f21b0 --- /dev/null +++ b/colossalai/nn/optimizer/fused_adam.py @@ -0,0 +1,142 @@ +# modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_adam.py +import torch + +from colossalai.registry import OPTIMIZERS +from colossalai.utils import multi_tensor_applier + + +@OPTIMIZERS.register_module +class FusedAdam(torch.optim.Optimizer): + """Implements Adam algorithm. + + Currently GPU-only. Requires ColossalAI to be installed via + ``pip install .``. + + This version of fused Adam implements 2 fusions. + + * Fusion of the Adam update's elementwise operations + * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. + + :class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, + or ``torch.optim.Adam`` with ``adamw_mode=False`` + + :class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp. + + Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED in FusedAdam! + adamw_mode (boolean, optional): Apply L2 regularization or weight decay + True for decoupled weight decay(also known as AdamW) (default: True) + set_grad_none (bool, optional): whether set grad to None when zero_grad() + method is called. (default: True) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + adamw_mode=True, + weight_decay=0., + amsgrad=False, + set_grad_none=True): + + if amsgrad: + raise RuntimeError('FusedAdam does not support the AMSGrad variant.') + defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay) + super(FusedAdam, self).__init__(params, defaults) + self.adamw_mode = 1 if adamw_mode else 0 + self.set_grad_none = set_grad_none + if multi_tensor_applier.available: + import colossalai._C.fused_optim + + # Skip buffer + self._dummy_overflow_buf = torch.cuda.IntTensor([0]) + self.multi_tensor_adam = colossalai._C.fused_optim.multi_tensor_adam + else: + raise RuntimeError('FusedAdam requires cuda extensions') + + def zero_grad(self, set_to_none=False): + if set_to_none: + for group in self.param_groups: + for p in group['params']: + p.grad = None + else: + super(FusedAdam, self).zero_grad() + + def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, div_scale: float = -1): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes. + """ + if any(p is not None for p in [grads, output_params, scale, grad_norms]): + raise RuntimeError( + 'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.' + ) + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + bias_correction = 1 if group['bias_correction'] else 0 + beta1, beta2 = group['betas'] + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + # create lists for multi-tensor apply + g_l, p_l, m_l, v_l = [], [], [], [] + + for p in group['params']: + if p.grad is None: + continue + if p.grad.data.is_sparse: + raise RuntimeError( + 'FusedAdam does not support sparse gradients, please consider SparseAdam instead') + + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + + if p.dtype not in [torch.float16, torch.float32]: + raise RuntimeError('FusedAdam only support fp16 and fp32.') + + g_l.append(p.grad.data) + p_l.append(p.data) + m_l.append(state['exp_avg']) + v_l.append(state['exp_avg_sq']) + + multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], + beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction, + group['weight_decay'], div_scale) + + return loss diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..2e33d703292a91835ccf33be82c6fc338b6ad855 --- /dev/null +++ b/colossalai/nn/optimizer/fused_lamb.py @@ -0,0 +1,193 @@ +# modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_lamb.py +import torch + +from colossalai.registry import OPTIMIZERS +from colossalai.utils import multi_tensor_applier + + +@OPTIMIZERS.register_module +class FusedLAMB(torch.optim.Optimizer): + """Implements LAMB algorithm. + + Currently GPU-only. Requires ColossalAI to be installed via + ``pip install .``. + + This version of fused LAMB implements 2 fusions. + + * Fusion of the LAMB update's elementwise operations + * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. + + :class:`colossalai.nn.optimizer.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer + + :class:`colossalai.nn.optimizer.FusedLAMB` may be used with or without Amp. + + LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its norm. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-6) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + NOT SUPPORTED now! (default: False) + adam_w_mode (boolean, optional): Apply L2 regularization or weight decay + True for decoupled weight decay(also known as AdamW) (default: True) + grad_averaging (bool, optional): whether apply (1-beta2) to grad when + calculating running averages of gradient. (default: True) + set_grad_none (bool, optional): whether set grad to None when zero_grad() + method is called. (default: True) + max_grad_norm (float, optional): value used to clip global grad norm + (default: 1.0) + use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0 + weight decay parameter (default: False) + + .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.01, + amsgrad=False, + adam_w_mode=True, + grad_averaging=True, + set_grad_none=True, + max_grad_norm=1.0, + use_nvlamb=False): + if amsgrad: + raise RuntimeError('FusedLAMB does not support the AMSGrad variant.') + defaults = dict(lr=lr, + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + max_grad_norm=max_grad_norm) + super(FusedLAMB, self).__init__(params, defaults) + if multi_tensor_applier.available: + import colossalai._C.fused_optim + self.multi_tensor_l2norm = colossalai._C.fused_optim.multi_tensor_l2norm + # Skip buffer + self._dummy_overflow_buf = torch.tensor([0], + dtype=torch.int, + device=self.param_groups[0]["params"][0].device) + self.multi_tensor_lamb = colossalai._C.fused_optim.multi_tensor_lamb + else: + raise RuntimeError('FusedLAMB requires cuda extensions') + + self.adam_w_mode = 1 if adam_w_mode else 0 + self.set_grad_none = set_grad_none + self.use_nvlamb = use_nvlamb + + def zero_grad(self): + if self.set_grad_none: + for group in self.param_groups: + for p in group['params']: + p.grad = None + else: + super(FusedLAMB, self).zero_grad() + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + # create separate grad lists for fp32 and fp16 params + g_all_32, g_all_16 = [], [] + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + if p.dtype == torch.float32: + g_all_32.append(p.grad.data) + elif p.dtype == torch.float16: + g_all_16.append(p.grad.data) + else: + raise RuntimeError('FusedLAMB only support fp16 and fp32.') + + device = self.param_groups[0]["params"][0].device + g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device) + # compute grad norm for two lists + if len(g_all_32) > 0: + g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_32], False)[0] + if len(g_all_16) > 0: + g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_16], False)[0] + + # blend two grad norms to get global grad norm + global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, + [[g_norm_32, g_norm_16]], False)[0] + max_grad_norm = self.defaults['max_grad_norm'] + + for group in self.param_groups: + bias_correction = 1 if group['bias_correction'] else 0 + beta1, beta2 = group['betas'] + grad_averaging = 1 if group['grad_averaging'] else 0 + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + # create lists for multi-tensor apply + g_16, p_16, m_16, v_16 = [], [], [], [] + g_32, p_32, m_32, v_32 = [], [], [], [] + + for p in group['params']: + if p.grad is None: + continue + if p.grad.data.is_sparse: + raise RuntimeError( + 'FusedLAMB does not support sparse gradients, please consider SparseAdam instead') + + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + + if p.dtype == torch.float16: + g_16.append(p.grad.data) + p_16.append(p.data) + m_16.append(state['exp_avg']) + v_16.append(state['exp_avg_sq']) + elif p.dtype == torch.float32: + g_32.append(p.grad.data) + p_32.append(p.data) + m_32.append(state['exp_avg']) + v_32.append(state['exp_avg_sq']) + else: + raise RuntimeError('FusedLAMB only support fp16 and fp32.') + + if (len(g_16) > 0): + multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16], + group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction, + group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm, + max_grad_norm, self.use_nvlamb) + if (len(g_32) > 0): + multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32], + group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction, + group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm, + max_grad_norm, self.use_nvlamb) + + return loss diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..03c3da28d2684a4e99937334a9585a132e3f04fd --- /dev/null +++ b/colossalai/nn/optimizer/fused_sgd.py @@ -0,0 +1,149 @@ +# modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_sgd.py +import torch +from torch.optim.optimizer import Optimizer, required + +from colossalai.registry import OPTIMIZERS +from colossalai.utils import multi_tensor_applier + + +@OPTIMIZERS.register_module +class FusedSGD(Optimizer): + r"""Implements stochastic gradient descent (optionally with momentum). + + Currently GPU-only. Requires ColossalAI to be installed via + ``pip install .``. + + This version of fused SGD implements 2 fusions. + + * Fusion of the SGD update's elementwise operations + * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. + + :class:`colossalai.nn.optimizer.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD`` + + :class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp. + + Nesterov momentum is based on the formula from + `On the importance of initialization and momentum in deep learning`__. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float): learning rate + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + + __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf + + .. note:: + The implementation of SGD with Momentum/Nesterov subtly differs from + Sutskever et. al. and implementations in some other frameworks. + Considering the specific case of Momentum, the update can be written as + + .. math:: + v = \rho * v + g \\ + p = p - lr * v + + where p, g, v and :math:`\rho` denote the parameters, gradient, + velocity, and momentum respectively. + This is in contrast to Sutskever et. al. and + other frameworks which employ an update of the form + + .. math:: + v = \rho * v + lr * g \\ + p = p - v + + The Nesterov version is analogously modified. + """ + + def __init__(self, + params, + lr=required, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + wd_after_momentum=False): + if lr is not required and lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super(FusedSGD, self).__init__(params, defaults) + + self.wd_after_momentum = wd_after_momentum + + if multi_tensor_applier.available: + import colossalai._C.fused_optim + + # Skip buffer + self._dummy_overflow_buf = torch.tensor([0], + dtype=torch.int, + device=self.param_groups[0]["params"][0].device) + self.multi_tensor_sgd = colossalai._C.fused_optim.multi_tensor_sgd + else: + raise RuntimeError('FusedSGD requires cuda extensions') + + def __setstate__(self, state): + super(FusedSGD, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('nesterov', False) + + def get_momentums(self, params): + momentums = [] + first_run = True + for p in params: + param_state = self.state[p] + # torch.optim.SGD initializes momentum in the main loop, we have + # to do it here, and track whether or not we've done so, so that + # momentum application can be skipped in the main kernel. + if 'momentum_buffer' not in param_state: + first_run = True + buf = param_state['momentum_buffer'] = torch.zeros_like(p) + momentums.append(buf) + else: + first_run = False + momentums.append(param_state['momentum_buffer']) + return momentums, first_run + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + # For each group, there are 3 possible combinations we need to consider: + # grad_type, param_to_update_type, momentum_type + # 1. fp16, fp16, fp16 + # 2. fp32, fp32, fp32 + # 3. fp16, fp32, fp32 + g_l, p_l = [], [] + for p in group['params']: + if p.grad is None: + continue + if p.grad.data.is_sparse: + raise RuntimeError('FusedSGD does not support sparse gradients') + g_l.append(p.grad) + p_l.append(p) + m_l, first_run = self.get_momentums(p_l) + multi_tensor_applier(self.multi_tensor_sgd, self._dummy_overflow_buf, [g_l, p_l, m_l], weight_decay, + momentum, dampening, group['lr'], nesterov, first_run, self.wd_after_momentum, 1.0) + + return loss diff --git a/colossalai/nn/optimizer/gemini_optimizer.py b/colossalai/nn/optimizer/gemini_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..31d1616126006ff568959333a97379328a2ffa44 --- /dev/null +++ b/colossalai/nn/optimizer/gemini_optimizer.py @@ -0,0 +1,15 @@ +from typing import Any + +import torch + +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer + +__all__ = ['GeminiAdamOptimizer'] + + +class GeminiAdamOptimizer(ZeroOptimizer): + + def __init__(self, model: torch.nn.Module, **defaults: Any) -> None: + optimizer = HybridAdam(model.parameters(), **defaults) + super().__init__(optimizer, model, **defaults) diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..a925c3d91d27a1ac36fdf8cefedec2ba2cb9d802 --- /dev/null +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -0,0 +1,151 @@ +from typing import Any, Optional + +import torch + +from colossalai.registry import OPTIMIZERS +from colossalai.utils import multi_tensor_applier + +from .nvme_optimizer import NVMeOptimizer + + +@OPTIMIZERS.register_module +class HybridAdam(NVMeOptimizer): + """Implements Adam algorithm. + + Supports parameters updating on both GPU and CPU, depanding on the device of paramters. + But the parameters and gradients should on the same device: + * Parameters on CPU and gradients on CPU is allowed. + * Parameters on GPU and gradients on GPU is allowed. + * Parameters on GPU and gradients on CPU is **not** allowed. + + Requires ColossalAI to be installed via ``pip install .`` + + This version of Hybrid Adam is an hybrid of CPUAdam and FusedAdam. + + * For parameters updating on CPU, it uses CPUAdam. + * For parameters updating on GPU, it uses FusedAdam. + * Hybird precision calculation of fp16 and fp32 is supported, eg fp32 parameters and fp16 gradients. + + :class:`colossalai.nn.optimizer.HybridAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, + or ``torch.optim.Adam`` with ``adamw_mode=False`` + + Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. + + Arguments: + model_params (iterable): iterable of parameters of dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED yet in CPUAdam! + adamw_mode (boolean, optional): Apply L2 regularization or weight decay + True for decoupled weight decay(also known as AdamW) (default: True) + simd_log (boolean, optional): whether to show if you are using SIMD to + accelerate. (default: False) + nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0. + nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files. + If it's ``None``, a random temporary directory will be used. Defaults to None. + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + # Number of fp32 shards for per parameter + # Param weight, grad, momentum and variance + num_fp32_shards_per_param = 4 + + def __init__(self, + model_params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + adamw_mode=True, + nvme_offload_fraction: float = 0.0, + nvme_offload_dir: Optional[str] = None, + **defaults: Any): + + default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) + super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) + self.adamw_mode = adamw_mode + try: + import colossalai._C.cpu_optim + import colossalai._C.fused_optim + except ImportError: + raise ImportError('Please install colossalai from source code to use HybridAdam') + + self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, + adamw_mode) + + self.gpu_adam_op = colossalai._C.fused_optim.multi_tensor_adam + self._dummy_overflow_buf = torch.cuda.IntTensor([0]) + + @torch.no_grad() + def step(self, closure=None, div_scale: float = -1): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + self._pre_step('exp_avg', 'exp_avg_sq') + for _, group in enumerate(self.param_groups): + g_l, p_l, m_l, v_l = [], [], [], [] + group_step = 0 + for _, p in enumerate(group['params']): + + if p.grad is None: + continue + + state = self.state[p] + + target_device = p.device + if len(state) == 0: + state['step'] = 0 + + # gradient momentums + state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + # gradient variances + state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + self._post_state_init(p) + + state['step'] += 1 + group_step = state['step'] + beta1, beta2 = group['betas'] + + if target_device.type == 'cpu': + assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" + assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" + self._pre_update(p, 'exp_avg', 'exp_avg_sq') + self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], + group['bias_correction'], p.data, p.grad.data, state['exp_avg'], + state['exp_avg_sq'], div_scale) + self._post_update(p, 'exp_avg', 'exp_avg_sq') + + elif target_device.type == 'cuda': + assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda" + assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda" + + # record the state by gruop and update at once + g_l.append(p.grad.data) + p_l.append(p.data) + m_l.append(state['exp_avg']) + v_l.append(state['exp_avg_sq']) + + else: + raise RuntimeError + if len(g_l) > 0: + adamw_mode = 1 if self.adamw_mode else 0 + bias_correction = 1 if group['bias_correction'] else 0 + multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], + group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode, + bias_correction, group['weight_decay'], div_scale) + self._post_step() + return loss diff --git a/colossalai/nn/optimizer/lamb.py b/colossalai/nn/optimizer/lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac2109572a443e004c58f91bfeb03f85dbbdc33 --- /dev/null +++ b/colossalai/nn/optimizer/lamb.py @@ -0,0 +1,111 @@ +""" +Adapted from the pytorch-lamb library at https://github.com/cybertronai/pytorch-lamb +""" + +import torch +from torch.optim import Optimizer + +from colossalai.registry import OPTIMIZERS + + +@OPTIMIZERS.register_module +class Lamb(Optimizer): + r"""Implements Lamb algorithm. + It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-6) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + adam (bool, optional): always use trust ratio = 1, which turns this into + Adam. Useful for comparison purposes. + + .. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0, adam=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.adam = adam + super(Lamb, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # Decay the first and second moment running average coefficient + # m_t + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + # v_t + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Paper v3 does not use debiasing. + # bias_correction1 = 1 - beta1 ** state['step'] + # bias_correction2 = 1 - beta2 ** state['step'] + # Apply bias to lr to avoid broadcast. + # * math.sqrt(bias_correction2) / bias_correction1 + step_size = group['lr'] + + weight_norm = p.data.pow(2).sum().sqrt() + + adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) + if group['weight_decay'] != 0: + adam_step.add_(p.data, alpha=group['weight_decay']) + + adam_norm = adam_step.pow(2).sum().sqrt() + if weight_norm == 0 or adam_norm == 0: + trust_ratio = 1 + else: + trust_ratio = weight_norm / adam_norm + state['weight_norm'] = weight_norm + state['adam_norm'] = adam_norm + state['trust_ratio'] = trust_ratio + if self.adam: + trust_ratio = 1 + + p.data.add_(adam_step, alpha=-step_size * trust_ratio) + + return loss diff --git a/colossalai/nn/optimizer/lars.py b/colossalai/nn/optimizer/lars.py new file mode 100644 index 0000000000000000000000000000000000000000..212f66671a0db9f580d99758823fdf78e3e54106 --- /dev/null +++ b/colossalai/nn/optimizer/lars.py @@ -0,0 +1,102 @@ +"""Adapted from https://github.com/NUS-HPC-AI-Lab/LARS-ImageNet-PyTorch/blob/main/lars.py""" + +from typing import Iterable + +import torch +from torch.optim import Optimizer + +from colossalai.registry import OPTIMIZERS + + +@OPTIMIZERS.register_module +class Lars(Optimizer): + r"""Implements the LARS optimizer from `"Large batch training of convolutional networks" + `_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + momentum (float, optional): momentum factor (default: 0) + eeta (float, optional): LARS coefficient as used in the paper (default: 1e-3) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + """ + + def __init__( + self, + params: Iterable[torch.nn.Parameter], + lr=1e-3, + momentum=0, + eeta=1e-3, + weight_decay=0, + epsilon=0.0 + ) -> None: + if not isinstance(lr, float) or lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if weight_decay < 0.0: + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay)) + if eeta <= 0 or eeta > 1: + raise ValueError("Invalid eeta value: {}".format(eeta)) + if epsilon < 0: + raise ValueError("Invalid epsilon value: {}".format(epsilon)) + defaults = dict(lr=lr, momentum=momentum, + weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True) + + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + eeta = group['eeta'] + lr = group['lr'] + lars = group['lars'] + eps = group['epsilon'] + + for p in group['params']: + if p.grad is None: + continue + decayed_grad = p.grad + scaled_lr = lr + if lars: + w_norm = torch.norm(p) + g_norm = torch.norm(p.grad) + trust_ratio = torch.where( + w_norm > 0 and g_norm > 0, + eeta * w_norm / (g_norm + weight_decay * w_norm + eps), + torch.ones_like(w_norm) + ) + trust_ratio.clamp_(0.0, 50) + scaled_lr *= trust_ratio.item() + if weight_decay != 0: + decayed_grad = decayed_grad.add(p, alpha=weight_decay) + decayed_grad = torch.clamp(decayed_grad, -10.0, 10.0) + + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = torch.clone( + decayed_grad).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(decayed_grad) + decayed_grad = buf + + p.add_(decayed_grad, alpha=-scaled_lr) + + return loss diff --git a/colossalai/nn/optimizer/nvme_optimizer.py b/colossalai/nn/optimizer/nvme_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb435a90f61c7a96c29dbf7aff16dc63cb651c3 --- /dev/null +++ b/colossalai/nn/optimizer/nvme_optimizer.py @@ -0,0 +1,160 @@ +import torch +import os +import tempfile +import math +from torch.nn.parameter import Parameter +from typing import Optional, List, Dict, Callable + + +class NVMeOptimizer(torch.optim.Optimizer): + """A base class for offloading optimizer states. + + Args: + params: parameters + defaults (dict): default dict + nvme_offload_fraction (float, optional): Fraction of params to be offloaded to NVMe. Defaults to 0.0. + offload_dir (Optional[str], optional): Directory to save NVMe offload files. + If it's ``None``, a random temporary directory will be used. Defaults to None. + + Raises: + ImportError: Raise if ``tensornvme`` is not installed. + """ + + def __init__(self, + params, + defaults: dict, + nvme_offload_fraction: float = 0.0, + offload_dir: Optional[str] = None) -> None: + assert 0.0 <= nvme_offload_fraction <= 1.0 + super().__init__(params, defaults) + self.nvme_offload_fraction = float(nvme_offload_fraction) + if self.nvme_offload_fraction > 0.0: + try: + from tensornvme import DiskOffloader + from tensornvme._C import get_backends + except ModuleNotFoundError: + raise ModuleNotFoundError('Please install tensornvme to use NVMeOptimizer') + self.offload_dir = offload_dir or tempfile.mkdtemp() + backend = 'uring' if 'uring' in get_backends() else 'aio' + self.offloader = DiskOffloader(self.offload_dir, 8, backend=backend) + else: + self.offload_dir = None + self.offloader = None + self.is_on_nvme: Dict[Parameter, bool] = {} + self.offloaded_numel: int = 0 + self.total_numel: int = self._get_numel() + self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction) + + self.prefetch_params: List[Parameter] = [] + self.param_to_prefetch_idx: Dict[Parameter, int] = {} + + def _get_numel(self) -> int: + numel = 0 + for group in self.param_groups: + for p in group['params']: + numel += p.storage().size() + return numel + + def _post_state_init(self, param: Parameter) -> None: + numel = param.storage().size() + if self.offloader is not None and param.device.type == 'cpu' and numel + self.offloaded_numel <= self.can_offload_numel: + self.is_on_nvme[param] = True + self.offloaded_numel += numel + else: + self.is_on_nvme[param] = False + + def _setup_prefetch_params(self) -> List[Parameter]: + if self.offloader is None: + return + assert len(self.prefetch_params) == 0 and len(self.param_to_prefetch_idx) == 0 + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + if len(self.state[p]) > 0 and self.is_on_nvme[p]: + assert p.device.type == 'cpu' + self.param_to_prefetch_idx[p] = len(self.prefetch_params) + self.prefetch_params.append(p) + + def _pre_step(self, *state_keys: str) -> None: + self._setup_prefetch_params() + if self.offloader is None or len(self.prefetch_params) == 0: + return + state = self.state[self.prefetch_params[0]] + for key in state_keys: + self.offloader.async_read(state[key]) + + def _pre_update(self, param: Parameter, *state_keys: str) -> None: + if self.offloader is None or param not in self.param_to_prefetch_idx: + return + self.offloader.sync_read_events() + idx = self.param_to_prefetch_idx[param] + if idx + 1 < len(self.prefetch_params): + state = self.state[self.prefetch_params[idx + 1]] + for key in state_keys: + self.offloader.async_read(state[key]) + + def _post_update(self, param: Parameter, *state_keys: str) -> None: + if self.offloader is None: + return + self.offloader.sync_write_events() + if self.is_on_nvme[param]: + state = self.state[param] + for key in state_keys: + self.offloader.async_write(state[key]) + + def _post_step(self) -> None: + if self.offloader is not None: + self.offloader.synchronize() + self.prefetch_params.clear() + self.param_to_prefetch_idx.clear() + + def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]: + """Performs a single optimization step (parameter update). + + Example: + + >>> self._pre_step('exp_avg', 'exp_avg_sq') + >>> for group in self.param_groups: + >>> for p in group['params']: + >>> if p.grad is None: + >>> continue + >>> state = self.state[p] + >>> if len(state) == 0: + >>> state['exp_avg'] = ... + >>> state['exp_avg_sq'] = ... + >>> self._post_state_init(p) + >>> if p.device.type == 'cpu': + >>> self._pre_update(p, 'exp_avg', 'exp_avg_sq') + >>> adam() + >>> self._post_update(p, 'exp_avg', 'exp_avg_sq') + >>> else: + >>> ... + >>> self._post_step() + + Args: + closure (Optional[Callable[[], float]], optional): A closure that reevaluates the model and + returns the loss. Optional for most optimizers. + """ + raise NotImplementedError + + def state_dict(self) -> dict: + # TODO(ver217): design a new method to save state_dict. When using NVMe offload, this method may lead to OOM. + if self.offloader is not None: + raise NotImplementedError + return super().state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + # TODO(ver217): design a new method to load state_dict. When using NVMe offload, whole state_dict may not be able to fit in memory. + if self.offloader is not None: + raise NotImplementedError + super().load_state_dict(state_dict) + + def __del__(self) -> None: + if getattr(self, 'offloader', None) is not None: + del self.offloader + if os.path.exists(self.offload_dir): + try: + os.rmdir(self.offload_dir) + except OSError: + pass diff --git a/colossalai/nn/optimizer/zero_optimizer.py b/colossalai/nn/optimizer/zero_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..2786d4496a8e938dc98cd8b254883bff82204283 --- /dev/null +++ b/colossalai/nn/optimizer/zero_optimizer.py @@ -0,0 +1,299 @@ +import math +from enum import Enum +from typing import Any, Dict, Set, Tuple + +import torch +import torch.distributed as dist +from torch.nn import Parameter +from torch.optim import Optimizer + +from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.gemini.chunk import Chunk, ChunkManager +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam +from colossalai.nn.parallel.data_parallel import ZeroDDP +from colossalai.utils import disposable, get_current_device + +_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam} + + +class OptimState(Enum): + SCALED = 0 + UNSCALED = 1 + + +class ZeroOptimizer(ColossalaiOptimizer): + """A wrapper for optimizer. ``ZeroDDP`` and ``ZeroOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3). + + Note: + You must use ``ZeroDDP`` with ``ZeroOptimizer``. + + Note: + Make sure you set ``placement_policy`` of ``GeminiManager`` to `"auto"`, + if you set ``gpu_margin_mem_ratio > 0``. + + Args: + optim (Optimizer): An Optimizer instance. + module (ZeroDDP): A ``ZeroDDP`` instance. + gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) + which will be used when using hybrid CPU optimizer. + This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto". + Defaults to 0.0. + initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. + min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. + growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. + backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5. + growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000. + hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2. + max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32. + """ + + def __init__(self, + optim: Optimizer, + module: ZeroDDP, + gpu_margin_mem_ratio: float = 0.0, + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + clipping_norm: float = 0.0, + norm_type: float = 2.0, + **defaults: Any): + super().__init__(optim) + assert isinstance(module, ZeroDDP) + assert type(optim) in _AVAIL_OPTIM_LIST, "you should use the optimizer in the available list" + self.module = module + self.gemini_manager = module.gemini_manager + self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager + self.optim_state = OptimState.UNSCALED + self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict() + self.param_to_chunk32: Dict[Parameter, Chunk] = dict() + self.chunk16_set: Set[Chunk] = set() + self.clipping_flag = clipping_norm > 0.0 + self.max_norm = clipping_norm + + if self.clipping_flag: + assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now" + + params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)] + for p, fp32_p in zip(params_list, module.fp32_params): + chunk_16 = self.chunk_manager.get_chunk(p) + if chunk_16 not in self.chunk16_set: + chunk_16.l2_norm_flag = self.clipping_flag + self.chunk16_set.add(chunk_16) + + self.__init__optimizer() + + # Grad scaler + self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale) + self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + self._logger = get_dist_logger() + + self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) + assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0' + # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid + # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors, + # and it must set `num_fp32_shards_per_param` correctly + self._should_move_fp32_params_h2d: bool = self.gemini_manager.is_cuda_margin_mem_avail and self.gpu_margin_mem_ratio > 0.0 and getattr( + optim, 'num_fp32_shards_per_param', 0) >= 2 + if self.gpu_margin_mem_ratio > 0.0 and not self.gemini_manager.is_cuda_margin_mem_avail: + self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0]) + + self._register_states = disposable(self._register_states_) + + def _set_grad_ptr(self): + for group in self.param_groups: + for fake_param in group['params']: + chunk32 = self.param_to_chunk32[fake_param] + begin, end = self.param_to_range[fake_param] + chunk16 = chunk32.paired_chunk + + fake_param.data = chunk16.payload[begin:end] + fake_param.grad = fake_param.data + fake_param.data = chunk32.payload[begin:end] + + def _update_fp16_params(self): + none_tensor = torch.empty([0]) + for group in self.param_groups: + for fake_param in group['params']: + assert fake_param.grad is None + fake_param.data = none_tensor + + for chunk16 in self.chunk16_set: + chunk16.optim_update() + + def _check_overflow(self): + # clear previous overflow record + self._found_overflow.fill_(self.module.overflow_counter) + + # all-reduce across global group + dist.all_reduce(self._found_overflow) + + return self._found_overflow.item() > 0 + + def _calc_global_norm(self) -> float: + norm_sqr: float = 0.0 + group_to_norm = dict() + for c16 in self.chunk16_set: + assert c16.l2_norm is not None + + if c16.is_gathered: + norm_sqr += c16.l2_norm + else: + # this chunk is sharded, use communication to collect total norm + if c16.torch_pg not in group_to_norm: + group_to_norm[c16.torch_pg] = 0.0 + group_to_norm[c16.torch_pg] += c16.l2_norm + + c16.l2_norm = None # clear l2 norm + + comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device()) + for group, part_norm in group_to_norm.items(): + comm_buffer.fill_(part_norm) + dist.all_reduce(comm_buffer, group=group) + norm_sqr += comm_buffer.item() + + global_norm = math.sqrt(norm_sqr) + return global_norm + + def _get_combined_scale(self): + loss_scale = 1 + + if self.optim_state == OptimState.SCALED: + loss_scale = self.loss_scale + self.optim_state = OptimState.UNSCALED + + combined_scale = loss_scale + if self.clipping_flag: + total_norm = self._calc_global_norm() + clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm + if clip > 1: + combined_scale = clip * loss_scale + + if combined_scale == 1: + return -1 + else: + return combined_scale + + @property + def loss_scale(self): + return self.grad_scaler.scale.item() + + def zero_grad(self, *args, **kwargs): + self.module.overflow_counter = 0 + return self.optim.zero_grad(set_to_none=True) + + def step(self, *args, **kwargs): + self._maybe_move_fp32_params() + self._set_grad_ptr() + + found_inf = self._check_overflow() + if found_inf: + self.optim_state = OptimState.UNSCALED # no need to unscale grad + self.grad_scaler.update(found_inf) # update gradient scaler + self._logger.info(f'Found overflow. Skip step') + self.zero_grad() # reset all gradients + self._update_fp16_params() + return + + # get combined scale. combined scale = loss scale * clipping norm + # so that gradient = gradient / combined scale + combined_scale = self._get_combined_scale() + self.grad_scaler.update(found_inf) + + ret = self.optim.step(div_scale=combined_scale, *args, **kwargs) + self._register_states() + self.zero_grad() + self._update_fp16_params() + return ret + + def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0): + raise NotImplementedError + + def backward(self, loss: torch.Tensor): + loss = self.loss_scale * loss + self.optim_state = OptimState.SCALED + self.module.backward(loss) + + def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor): + # This function is called except the last stage of pipeline parallel + # It receives the scaled grad from the previous rank + # No need to scale the grad again + # Need to unscale when optimizing + self.optim_state = OptimState.SCALED + self.module.backward_by_grad(tensor, grad) + + def _maybe_move_fp32_params(self): + if self._should_move_fp32_params_h2d: + self._should_move_fp32_params_h2d = False + available_cuda_margin_mem = self.gemini_manager.cuda_margin_mem * self.gpu_margin_mem_ratio + fp32_params_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param + fp32_params_used_cuda_margin_mem = 0 + + for group in self.param_groups: + for fake_param in group['params']: + chunk32 = self.param_to_chunk32[fake_param] + chunk16 = chunk32.paired_chunk + + if chunk32.device_type == 'cuda': + continue + + if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem: + self.chunk_manager.move_chunk(chunk32, get_current_device()) + # stores grad now + self.chunk_manager.move_chunk(chunk16, get_current_device()) + self.module.set_chunk_grad_device(chunk16, get_current_device()) + fp32_params_used_cuda_margin_mem += chunk32.payload_mem + + for group in self.param_groups: + for fake_param in group['params']: + chunk32 = self.param_to_chunk32[fake_param] + if chunk32.device_type == 'cuda': + state = self.optim.state[fake_param] + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(get_current_device()) + + def _register_states_(self): + for group in self.optim.param_groups: + for p in group['params']: + state = self.optim.state[p] + for val in state.values(): + if isinstance(val, torch.Tensor): + self.chunk_manager.add_extern_static_tensor(val) + + def __init__optimizer(self): + + def get_range_pair(local_chunk: Chunk, local_param: Parameter): + param_info = local_chunk.tensors_info[local_param] + if local_chunk.keep_gathered: + return param_info.offset, param_info.end + begin = max(0, param_info.offset - local_chunk.shard_begin) + end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin) + return begin, end + + for group in self.optim.param_groups: + fake_params_list = list() + + for param in group['params']: + chunk16 = self.chunk_manager.get_chunk(param) + range_pair = get_range_pair(chunk16, param) + if range_pair[0] >= range_pair[1]: + continue + + fake_param = torch.nn.Parameter(torch.empty([0])) + self.param_to_chunk32[fake_param] = chunk16.paired_chunk + self.param_to_range[fake_param] = range_pair + + fake_params_list.append(fake_param) + + group['params'] = fake_params_list diff --git a/colossalai/nn/parallel/__init__.py b/colossalai/nn/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0c369bfce22fa069a7f5335e1792c2e84a038717 --- /dev/null +++ b/colossalai/nn/parallel/__init__.py @@ -0,0 +1,4 @@ +from .data_parallel import ColoDDP, ZeroDDP +from .gemini_parallel import GeminiDDP + +__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP'] diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..54f6eb9b739617bbc397fcb3052243cd29416277 --- /dev/null +++ b/colossalai/nn/parallel/data_parallel.py @@ -0,0 +1,581 @@ +import itertools +from collections import OrderedDict +from functools import partial +from typing import Dict, Iterable, List, Optional, Set + +import torch +import torch.distributed as dist + +from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.gemini.memory_tracer import OrderedParamGenerator +from colossalai.logging import get_dist_logger +from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda +from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.utils import get_current_device +from colossalai.zero.utils.gemini_hook import GeminiZeROHook + +from .reducer import Reducer + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + + +def free_storage(data: torch.Tensor) -> None: + """Free underlying storage of a Tensor.""" + if data.storage().size() > 0: + # Since we're modifying the Tensor's Storage directly, make sure the Tensor + # is the sole occupant of the Storage. + assert data.storage_offset() == 0 + data.storage().resize_(0) + + +def _cast_float(args, dtype: torch.dtype): + if isinstance(args, torch.Tensor) and torch.is_floating_point(args): + args = args.to(dtype) + elif isinstance(args, (list, tuple)): + args = type(args)(_cast_float(t, dtype) for t in args) + elif isinstance(args, dict): + args = {k: _cast_float(v, dtype) for k, v in args.items()} + return args + + +class ColoDDP(torch.nn.Module): + """Distributed data parallel for ColoTensor. Nested ColoDDP is not supported now. + + Example: + >>> from colossalai.core import global_context as gpc + >>> from colossalai.context import ParallelMode + >>> model = torch.nn.Linear(20, 1) + >>> pg = ProcessGroup(tp_degree = world_size//2) + >>> model = ColoDDP(model, pg) + >>> logits = model(x) + >>> loss = criterion(logits, labels) + >>> model.backward(loss) + + Args: + module (torch.nn.Module): Module to apply DDP. + process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses. + If it's None, the default data parallel group will be used. Defaults to None. + """ + + def __init__(self, + module: torch.nn.Module, + process_group: ColoProcessGroup, + bucket_cap_mb: int = 25, + rebuild_bucket: bool = True) -> None: + assert not isinstance(module, ColoDDP) + super().__init__() + self.module = module + self.comm_stream: torch.cuda.Stream = torch.cuda.Stream() + assert process_group + + self.process_group = process_group + self.dp_world_size = self.process_group.dp_world_size() + + self.reducer = Reducer(bucket_cap_mb) + self.rebuild_bucket = rebuild_bucket + for p in module.parameters(): + if getattr(p, '_ddp_to_ignore', False): + continue + if p.requires_grad: + p.register_hook(partial(self.grad_handle, p)) + + def parameters(self, recurse: bool = True): + return self.module.parameters(recurse) + + def named_parameters(self, prefix: str = '', recurse: bool = True): + return self.module.named_parameters(prefix, recurse) + + def named_buffers(self, prefix: str = '', recurse: bool = True): + return self.module.named_buffers(prefix, recurse) + + def named_children(self): + return self.module.named_children() + + def named_modules(self, + memo: Optional[Set[torch.nn.Module]] = None, + prefix: str = '', + remove_duplicate: bool = True): + return self.module.named_modules(memo, prefix, remove_duplicate) + + def forward(self, *args, **kwargs): + self.module.zero_grad(set_to_none=True) + return self.module(*args, **kwargs) + + def backward(self, loss: torch.Tensor): + loss.backward() + with torch.cuda.stream(self.comm_stream): + self.reducer.flush() + torch.cuda.current_stream().wait_stream(self.comm_stream) + if self.rebuild_bucket: + self.reducer.free() + for p in self.module.parameters(): + if getattr(p, '_ddp_to_ignore', False): + continue + if p.grad.device.type != "cpu": + p.grad = p._saved_grad + + def grad_handle(self, p, grad): + if grad.device.type != "cpu": + empty_grad = torch.empty_like(grad) + free_storage(empty_grad) + if self.dp_world_size > 1: + grad = grad / self.dp_world_size + self.comm_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.comm_stream): + self.reducer.all_reduce_async(grad, + group=self.process_group.dp_process_group(), + callback_fn=partial(self._save_grad, p)) + grad.record_stream(self.comm_stream) + else: + ColoDDP._save_grad(p, grad) + return empty_grad + + else: + # TODO(jiaruifang) fixme + self.process_group.set_cpu_groups() + dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group()) + return grad + + @staticmethod + def _save_grad(p, grad): + if hasattr(p, '_saved_grad'): + p._saved_grad.add_(grad) + else: + p._saved_grad = grad + + def zero_grad(self, set_to_none: bool = False) -> None: + self.module.zero_grad(set_to_none=True) + for p in self.module.parameters(): + if getattr(p, '_saved_grad', None) is not None: + if set_to_none: + p._saved_grad = None + else: + if p._saved_grad.grad_fn is not None: + p._saved_grad.detach_() + else: + p._saved_grad.requires_grad_(False) + p._saved_grad.zero_() + + @staticmethod + def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None: + """Sets parameters to be ignored by DDP. + This method must be called before initializing ColoDDP. + + Example: + >>> params_to_ignore = [] + >>> for p in module.parameters(): + >>> if should_ignore(p): + >>> params_to_ignore.append(p) + >>> ColoDDP.set_params_to_ignore(params_to_ignore) + >>> module = ColoDDP(module) + + Args: + params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored. + """ + for p in params_to_ignore: + p._ddp_to_ignore = True + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + + def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): + return self.module.load_state_dict(state_dict, strict) + + +class ZeroDDP(ColoDDP): + """ZeRO DDP for ColoTensor. + Warning: Nested ZeroDDP is not supported now. + It is designed to be used with ChunkManager and GeminiManager. + For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``. + + Args: + module (torch.nn.Module): Module to apply ZeRO-DP. + gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space. + For more details, see the API reference of ``GeminiManager``. + pin_memory (bool): Chunks on CPU Memory use pin-memory. + force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False. + """ + + def __init__(self, + module: torch.nn.Module, + gemini_manager: GeminiManager, + pin_memory: bool = False, + force_outputs_fp32: bool = False) -> None: + super().__init__(module, process_group=ColoProcessGroup()) + self.gemini_manager = gemini_manager + self.chunk_manager: ChunkManager = gemini_manager.chunk_manager + self.force_outputs_fp32 = force_outputs_fp32 + self.param_op_hook = GeminiZeROHook(gemini_manager) + self.fp32_params: List[ColoTensor] = [] + self.overflow_counter = 0 + self.grads_device: Dict[torch.Tensor, torch.device] = {} + + cpu_offload = self.gemini_manager.policy_name != 'cuda' + + if self.gemini_manager._premade_memstats_: + # build chunk in param runtime visited order. + param_order = self.gemini_manager.memstats()._param_runtime_order + else: + # build chunk in param initialized order. + # Note: in this way, it can not get filter unused params during runtime. + param_order = OrderedParamGenerator() + for p in module.parameters(): + param_order.append(p) + + for p in param_order.generate(): + assert isinstance(p, ColoParameter) + + if getattr(p, '_ddp_to_ignore', False): + p.data = p.data.half() + continue + + fp32_data = p.data.float() + fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) + p.data = p.data.half() + dp_world_size = p.process_group.dp_world_size() + self.chunk_manager.register_tensor(tensor=p, + group_type='fp16_param', + config_key=dp_world_size, + cpu_offload=cpu_offload, + pin_memory=pin_memory) + self.chunk_manager.register_tensor(tensor=fp32_p, + group_type='fp32_param', + config_key=dp_world_size, + cpu_offload=cpu_offload, + pin_memory=pin_memory) + self.fp32_params.append(fp32_p) + self.grads_device[p] = self.gemini_manager.default_device + self.chunk_manager.close_all_groups() + self._cast_buffers() + + params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)] + for p, fp32_p in zip(params_list, self.fp32_params): + chunk_16 = self.chunk_manager.get_chunk(p) + chunk_32 = self.chunk_manager.get_chunk(fp32_p) + chunk_32.init_pair(chunk_16) + + # keep gathered chunks are in CUDA + if chunk_16.keep_gathered: + self.grads_device[p] = get_current_device() + + self._logger = get_dist_logger() + + def forward(self, *args, **kwargs): + args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) + self.module.zero_grad(set_to_none=True) + self.gemini_manager.pre_iter(*args) + with ColoParamOpHookManager.use_hooks(self.param_op_hook): + outputs = self.module(*args, **kwargs) + if self.force_outputs_fp32: + return _cast_float(outputs, torch.float) + return outputs + + def _setup_grads_ptr(self): + for p in self.module.parameters(): + if getattr(p, '_ddp_to_ignore', False): + continue + p.grad = None + + def _post_backward(self): + assert self.chunk_manager.accessed_mem == 0 + self._setup_grads_ptr() + self._logger.debug( + f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}' + ) + self.gemini_manager.post_iter() + + def backward(self, loss: torch.Tensor): + with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): + loss.backward() + self._post_backward() + + def backward_by_grad(self, tensor, grad): + with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): + torch.autograd.backward(tensor, grad) + self._post_backward() + + def grad_handle(self, p, grad): + empty_grad = torch.empty_like(grad) + free_storage(empty_grad) + with torch._C.DisableTorchFunction(): + self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) + chunk = self.chunk_manager.get_chunk(p) + chunk.copy_tensor_to_chunk_slice(p, grad) + reduced = self.chunk_manager.reduce_chunk(chunk) + if reduced: + if chunk.is_gathered: + chunk.cuda_global_chunk.div_(chunk.pg_size) + else: + chunk.cuda_shard.div_(chunk.pg_size) + # check overflow elements + self.overflow_counter += chunk.has_inf_or_nan + # record l2 norm for gradient clipping + if chunk.l2_norm_flag: + chunk.set_l2_norm() + self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) + return empty_grad + + def zero_grad(self, set_to_none: bool = False) -> None: + self.module.zero_grad(set_to_none=True) + + def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None: + for tensor in chunk.get_tensors(): + self.grads_device[tensor] = device + + def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): + r"""Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + Parameters and buffers set to ``None`` are not included. + + Returns: + dict: + a dictionary containing a whole state of the module + + Example: + + >>> module.state_dict().keys() + ['bias', 'weight'] + + """ + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) + self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0) + + for hook in self._state_dict_hooks.values(): + hook_result = hook(self, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): + r"""Saves module state to `destination` dictionary, containing a state + of the module, but not its descendants. This is called on every + submodule in :meth:`~torch.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. + + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now." + + # save parameters + param_to_save_data = dict() + chunk_list = self.chunk_manager.get_chunks(self.fp32_params) + for chunk in chunk_list: + temp_chunk = get_temp_total_chunk_on_cuda(chunk) + + for tensor, tensor_info in chunk.tensors_info.items(): + record_tensor = torch.empty([0]) + record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) + if record_flag: + record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu() + + assert tensor not in param_to_save_data + param_to_save_data[tensor] = record_tensor + + del temp_chunk + + for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): + if p is not None: + assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) + record_parameter = param_to_save_data[fp32_p] + destination[prefix + name] = record_parameter + + # save all buffers + for name, buf in self.named_buffers(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + # save extra states + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + destination[extra_state_key] = self.get_extra_state() + + def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): + r"""Copies parameters and buffers from :attr:`state_dict` into + this module and its descendants. If :attr:`strict` is ``True``, then + the keys of :attr:`state_dict` must exactly match the keys returned + by this module's :meth:`~torch.nn.Module.state_dict` function. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + strict (bool, optional): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + Note: + If a parameter or buffer is registered as ``None`` and its corresponding key + exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a + ``RuntimeError``. + """ + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + error_msgs: List[str] = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + # mypy isn't aware that "_metadata" exists in state_dict + state_dict._metadata = metadata # type: ignore[attr-defined] + + prefix = '' + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join( + '"{}"'.format(k) for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) + + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + self.__class__.__name__, "\n\t".join(error_msgs))) + return _IncompatibleKeys(missing_keys, unexpected_keys) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self.named_buffers() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + def load(param_name, dest_tensor, copy_func): + state_key = prefix + param_name + if state_key in state_dict: + input_param = state_dict[state_key] + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + if input_param.shape != dest_tensor.shape: + # local shape should match the one in checkpoint + error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' + 'the shape in current model is {}.'.format(state_key, input_param.shape, + dest_tensor.shape)) + return + try: + with torch.no_grad(): + copy_func(input_param) + except Exception as ex: + error_msgs.append('While copying the parameter named "{}", ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}, ' + 'an exception occurred : {}.'.format(state_key, dest_tensor.size(), + input_param.size(), ex.args)) + elif strict: + missing_keys.append(state_key) + + def load_fp32_parameter(chunk_slice, data): + chunk_slice.copy_(data.flatten()) + + fp32_to_name = dict() + for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): + if p is not None: + fp32_to_name[fp32_p] = name + + chunk_list = self.chunk_manager.get_chunks(self.fp32_params) + for chunk in chunk_list: + temp_chunk = get_temp_total_chunk_on_cuda(chunk) + + for tensor, tensor_info in chunk.tensors_info.items(): + parameter_name = fp32_to_name[tensor] + parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end] + load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice)) + + if chunk.is_gathered: + chunk.cuda_global_chunk.copy_(temp_chunk) + elif chunk.cuda_shard is not None: + chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + else: + chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + + del temp_chunk + + for chunk_32 in chunk_list: + chunk_16 = chunk_32.paired_chunk + assert chunk_16 is not None + chunk_16.optim_update() + + for name, buf in persistent_buffers.items(): + if buf is not None: + load(name, buf, buf.copy_) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", + torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix):] + if input_name not in local_state: + unexpected_keys.append(key) + + def _cast_buffers(self): + for buffer in self.module.buffers(): + buffer.data = buffer.cuda() + if torch.is_floating_point(buffer): + buffer.data = buffer.half() diff --git a/colossalai/nn/parallel/gemini_parallel.py b/colossalai/nn/parallel/gemini_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..cd5ef424a1d9e4f4c9490b368be0f6babdf0438c --- /dev/null +++ b/colossalai/nn/parallel/gemini_parallel.py @@ -0,0 +1,57 @@ +from typing import Optional + +import torch + +from colossalai.gemini.chunk import init_chunk_manager +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.gemini.memory_tracer import MemStats + +from .data_parallel import ZeroDDP + + +class GeminiDDP(ZeroDDP): + + def __init__(self, + module: torch.nn.Module, + device: torch.device, + placement_policy: str = "cpu", + pin_memory: bool = False, + force_outputs_fp32: bool = False, + search_range_mb: int = 32, + hidden_dim: Optional[int] = None, + min_chunk_size_mb: Optional[float] = None, + memstats: Optional[MemStats] = None) -> None: + """ + A torch.Module warpper using ZeRO-DP and Genimi. + ZeRO is for parallel. Gemini is for memory management. + WARNING: The class will modify the module inline! + + Example: + model is initialized under the context of ColoInitContext + >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda") + >>> logits = model(x) + >>> loss = criterion(logits, labels) + >>> model.backward(loss) + + Args: + module (torch.nn.Module): the model to be wrapped. + device (torch.device): device to place the model. + placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu". + pin_memory (bool, optional): use pin memory on CPU. Defaults to False. + force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. + search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32. + hidden_dim (int, optional): the hidden dimension of DNN. + Users can provide this argument to speed up searching. + If users do not know this argument before training, it is ok. We will use a default value 1024. + min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte. + If the aggregate size of parameters is still samller than the minimum chunk size, + all parameters will be compacted into one small chunk. + memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer. + """ + chunk_manager = init_chunk_manager(model=module, + init_device=device, + hidden_dim=hidden_dim, + search_range_mb=search_range_mb, + min_chunk_size_mb=min_chunk_size_mb) + gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) + super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32) diff --git a/colossalai/nn/parallel/layers/__init__.py b/colossalai/nn/parallel/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..29b8353e63c5930950d28d58b0c122c76486c3e1 --- /dev/null +++ b/colossalai/nn/parallel/layers/__init__.py @@ -0,0 +1,14 @@ +from .colo_module import ColoModule +from .linear import ColoLinear +from .embedding import ColoEmbedding +from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module + +from .cache_embedding import CachedEmbeddingBag, ParallelCachedEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \ + ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelCachedEmbeddingBagTablewiseSpiltCache + +__all__ = [ + 'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module', + 'ColoLinear', 'ColoEmbedding', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'CachedParamMgr', + 'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', + 'ParallelCachedEmbeddingBagTablewiseSpiltCache' +] diff --git a/colossalai/nn/parallel/layers/cache_embedding/__init__.py b/colossalai/nn/parallel/layers/cache_embedding/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5bbc931a79dceeffcf827fc3fd3b18ca8bf87dd3 --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/__init__.py @@ -0,0 +1,13 @@ +from .cache_mgr import CachedParamMgr, EvictionStrategy +from .copyer import LimitBuffIndexCopyer +from .cached_embedding import CachedEmbeddingBag +from .parallel_cached_embedding import ParallelCachedEmbeddingBag +from .embedding_config import TablewiseEmbeddingBagConfig +from .parallel_cached_embedding_tablewise import ParallelCachedEmbeddingBagTablewise +from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache + +__all__ = [ + 'CachedParamMgr', 'LimitBuffIndexCopyer', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'EvictionStrategy', + 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', + 'ParallelCachedEmbeddingBagTablewiseSpiltCache' +] diff --git a/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..705835a0ed22ef5c9f6ecc388ddcf8d4e2ea3073 --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py @@ -0,0 +1,36 @@ +import abc +import torch.nn as nn + + +class BaseEmbeddingBag(abc.ABC, nn.Module): + + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2., + scale_grad_by_freq=False, + sparse=False, + mode='mean', + include_last_offset=False, + ): + super(BaseEmbeddingBag, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + + # Specific to embedding bag + self.mode = mode + self.include_last_offset = include_last_offset diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py new file mode 100644 index 0000000000000000000000000000000000000000..da043df368ae17ed1b95cc42b90e58994c84f9aa --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py @@ -0,0 +1,583 @@ +import numpy as np +import torch +from torch.profiler import record_function +from typing import List, Optional +from contexttimer import Timer +from .copyer import LimitBuffIndexCopyer +from enum import Enum +import sys +from contextlib import contextmanager + + +class EvictionStrategy(Enum): + LFU = 1 + # dataset aware eviction strategy + DATASET = 2 + + +def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None: + if stream is None: + return + torch.cuda.current_stream().wait_stream(stream) + # As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html, + # PyTorch uses the "caching allocator" for memroy allocation for tensors. When a tensor is + # freed, its memory is likely to be reused by newly constructed tenosrs. By default, + # this allocator traces whether a tensor is still in use by only the CUDA stream where it + # was created. When a tensor is used by additional CUDA streams, we need to call record_stream + # to tell the allocator about all these streams. Otherwise, the allocator might free the + # underlying memory of the tensor once it is no longer used by the creator stream. This is + # a notable programming trick when we write programs using multi CUDA streams. + cur_stream = torch.cuda.current_stream() + assert isinstance(t, torch.Tensor) + t.record_stream(cur_stream) + + +class CachedParamMgr(torch.nn.Module): + """ + Manage Embedding Weights on CPU and CUDA memory uses a software cache. + CPU maintains the entire original weight. + CUDA maintains a fraction of the weights used in the upcoming computation. The row number in CUDA is controlled by `cuda_row_num`. + During training, GPU needs to transmit embedding rows between CPU and GPU. + Args: + weight (torch.Tensor): the weight of the Embedding layer. + cuda_row_num (int, optional): the number of rows cached in CUDA memory. Defaults to 0. + buffer_size (int, optional): the number of rows in a data transmitter buffer. Defaults to 50_000. + pin_weight (bool, optional): use pin memory to store the cpu weight. If set `True`, the cpu memory usage will increase largely. Defaults to False. + evict_strategy (EvictionStrategy, optional): the eviction strategy. There are two options. + `EvictionStrategy.LFU`: use the least frequently used cache. + `EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume. + Defaults to EvictionStrategy.DATASET. + """ + + def __init__( + self, + weight: torch.Tensor, + cuda_row_num: int = 0, + buffer_size: int = 0, + pin_weight: bool = True, + evict_strategy: EvictionStrategy = EvictionStrategy.DATASET, + async_copy: bool = False, + ) -> None: + super(CachedParamMgr, self).__init__() + self.buffer_size = buffer_size + self.num_embeddings, self.embedding_dim = weight.shape + self.cuda_row_num = cuda_row_num + self._cuda_available_row_num = self.cuda_row_num + self.pin_weight = pin_weight + self.elem_size_in_byte = weight.element_size() + + # weight configure + self._init_weight(weight) + + # Perf log + self.num_hits_history = [] + self.num_miss_history = [] + self.num_write_back_history = [] + + self._evict_strategy = evict_strategy + + self._async_copy = async_copy + + if self._async_copy: + self._memcpy_stream = torch.cuda.Stream() + + print('use async copy') + + if self._evict_strategy == EvictionStrategy.LFU: + # cache_row_idx -> frequency, freq of the cache rows. + # classic lfu cache. evict the minimal freq value row in cuda cache. + self.register_buffer("freq_cnter", + torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), + dtype=torch.long).fill_(sys.maxsize), + persistent=False) + self._elapsed_dict = {} + self._show_cache_miss = True + self._reset_comm_stats() + + def _reset_comm_stats(self): + for k in self._elapsed_dict.keys(): + self._elapsed_dict[k] = 0 + + self._cpu_to_cuda_numel = 0 + self._cuda_to_cpu_numel = 0 + if self._show_cache_miss: + self._cache_miss = 0 + self._total_cache = 0 + + @contextmanager + def timer(self, name): + with Timer() as t: + yield + torch.cuda.synchronize() + + if name not in self._elapsed_dict.keys(): + self._elapsed_dict[name] = 0 + self._elapsed_dict[name] += t.elapsed + + def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor: + """_find_evict_gpu_idxs + Find the gpu idxs to be evicted, according to their freq. + Args: + evict_num (int): how many rows has to be evicted + Returns: + torch.Tensor: a list tensor (1D), contains the gpu_row_idxs. + """ + if self._evict_strategy == EvictionStrategy.LFU: + # find the minimal evict_num freq entries in cached_idx_map + _, evict_gpu_row_idxs = torch.topk(self.freq_cnter, evict_num, largest=False) + return evict_gpu_row_idxs + elif self._evict_strategy == EvictionStrategy.DATASET: + # cached_idx_map itself implies the priority of eviction. + # The value of self.cached_idx_map represents cpu_row_idx. + # The larger it is, the less frequently it will appear in the dataset, + # and the higher its eviction priority will be. + _, evict_gpu_row_idxs = torch.topk(self.cached_idx_map, evict_num, largest=True) + return evict_gpu_row_idxs + else: + raise TypeError + + def _init_weight(self, weight): + if self.cuda_row_num > 0: + # Enable cache with introducing auxiliary data structures + self.cuda_cached_weight = torch.nn.Parameter( + torch.zeros(self.cuda_row_num, + self.embedding_dim, + device=torch.cuda.current_device(), + dtype=weight.dtype)) + + # pin memory cpu for higher CPU-GPU copy bandwidth + self.weight = weight.pin_memory() if self.pin_weight else weight + # map original id to new id with respect to frequency + # id -> cpu_row_idx + self.register_buffer( + "idx_map", + torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()), + persistent=False, + ) + + # cached_idx_map: gpu_row_idx -> cpu_row_idx + self.register_buffer("cached_idx_map", + torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), + dtype=torch.long).fill_(-1), + persistent=False) + + # cpu_row_id -> gpu_row_idx. + # gpu_row_idx as -1 means cpu_row_id not in CUDA. + self.register_buffer("inverted_cached_idx", + torch.zeros(self.num_embeddings, device=torch.cuda.current_device(), + dtype=torch.long).fill_(-1), + persistent=False) + + self.evict_backlist = torch.tensor([], device=torch.cuda.current_device()) + + # index copy buffer size should less than 10% of cuda weight. + if self.buffer_size > 0: + self.limit_buff_index_copyer = LimitBuffIndexCopyer(self.buffer_size) + + else: + # Disable cache so that FreqCacheEmbedding is compatible with vanilla EmbeddingBag + # self.weight = torch.nn.Parameter(weight) + # self.cuda_cached_weight = self.weight + raise NotImplementedError() + + def cpu_weight_data(self, row_idx: int) -> torch.Tensor: + """ + access a row of CPU weight. + Args: + row_idx (int): the idx of rows + Returns: + torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor is 1-D. + """ + + return self.weight.data.view(-1).narrow(0, + int(row_idx) * self.embedding_dim, + self.embedding_dim).view(1, self.embedding_dim) + + @property + def cuda_available_row_num(self): + return self._cuda_available_row_num + + @torch.no_grad() + def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7): + """reorder + reorder the weight according to ids' frequency in dataset before training. + Execute only once before training, also known as warmup phase. + + Note: + If you would like to use the DATASET as the eviction strategy, you must call this function. + Note: + If you are use the LFU as the eviction strategy, you can skip this function. If you still use this function. It will initialize + The frequency in LFU cache using the dataset statistics. + Args: + ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight. + warmup_ratio (float): the amount of chunks preloaded in cuda cache + """ + # reorder phase: reorder the cpu weight according to their freq stats in the target dataset. + # reorder only works for DATASET eviction strategy. + + if ids_freq_mapping is not None and not isinstance(ids_freq_mapping, torch.Tensor): + ids_freq_mapping = torch.tensor(ids_freq_mapping) + + if self._evict_strategy == EvictionStrategy.DATASET: + if ids_freq_mapping is not None: + tmp_idx = torch.argsort(ids_freq_mapping, descending=True) + sorted_idx = torch.argsort(tmp_idx) + self.idx_map.data.copy_(sorted_idx) + + # warmup phase: copy #preload_row_num rows from cpu to gpu. + preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings) + if preload_row_num > 0: + with Timer() as timer: + # extract rows from cpu weight + if self._evict_strategy == EvictionStrategy.LFU and ids_freq_mapping is not None: + freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True) + preload_cuda_row_idxs = torch.arange(preload_row_num).cuda() + else: + preload_cpu_ids = torch.arange(preload_row_num) + preload_cuda_row_idxs = preload_cpu_ids.cuda() + if self.buffer_size > 0: + self.limit_buff_index_copyer.index_copy(0, + src_index=preload_cpu_ids, + tgt_index=preload_cuda_row_idxs, + src=self.weight.view(self.num_embeddings, -1), + tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) + else: + preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda() + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs, + preload_rows) + + # update auxiliary info + self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda() + self.inverted_cached_idx[preload_cpu_ids] = preload_cuda_row_idxs + self._cuda_available_row_num -= preload_row_num + + if self._evict_strategy == EvictionStrategy.LFU: + # if the ids_freq_mapping is not None, we initialize the embedding row's freq value in LFU as its freq in dataset. + if ids_freq_mapping is None: + self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0) + else: + self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda() + + print(f'Cache warmup finished cost {timer.elapsed} sec.') + + def flush(self): + """flush all CUDA rows to CPU. + The function is usually called after training finished. + """ + slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1) + row_ids = self.cached_idx_map[slots] + rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu() + self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows) + self.cached_idx_map.index_fill_(0, slots, -1) + self.inverted_cached_idx.index_fill_(0, row_ids, -1) + self._cuda_available_row_num += slots.numel() + + if self._show_cache_miss: + self._cache_miss = 0 + self._total_cache = 0 + + if self._evict_strategy == EvictionStrategy.LFU: + self.freq_cnter.fill_(sys.maxsize) + assert self._cuda_available_row_num == self.cuda_row_num + assert torch.all(self.inverted_cached_idx == -1).item() + assert torch.all(self.cached_idx_map == -1).item() + + def print_comm_stats(self): + if self._cuda_to_cpu_numel > 0 and "3_evict_out" in self._elapsed_dict: + elapsed = self._elapsed_dict["3_evict_out"] + print( + f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cuda_to_cpu_numel / 1e6} M elem" + ) + print(f'cuda_to_cpu_elapse {elapsed} sec') + if self._cpu_to_cuda_numel > 0 and "5_evict_in" in self._elapsed_dict: + elapsed = self._elapsed_dict["5_evict_in"] + print( + f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem" + ) + print(f'cpu_to_cuda_elpase {elapsed} sec') + + for k, v in self._elapsed_dict.items(): + print(f'{k}: {v}') + + print(f'cache miss ratio {self._cache_miss / self._total_cache}') + + @torch.no_grad() + def _id_to_cached_cuda_id(self, ids: torch.Tensor) -> torch.Tensor: + """ + convert ids to indices in self.cuda_cached_weight. + Implemented with parallel operations on GPU. + Args: + ids (torch.Tensor): ids from the dataset + Returns: + torch.Tensor: contains indices in self.cuda_cached_weight + """ + ids = self.idx_map.index_select(0, ids.view(-1)) + ret = self.inverted_cached_idx.index_select(0, ids) + return ret + + @torch.no_grad() + def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor: + """ + move the cpu embedding rows w.r.t. ids into CUDA memory + Args: + ids (torch.Tensor): the ids to be computed + Returns: + torch.Tensor: indices on the cuda_cached_weight. + """ + torch.cuda.synchronize() + with self.timer("cache_op") as gtimer: + # identify cpu rows to cache + with self.timer("1_identify_cpu_row_idxs") as timer: + with record_function("(cache) get unique indices"): + if self._evict_strategy == EvictionStrategy.LFU: + cpu_row_idxs, repeat_times = torch.unique(ids, return_counts=True) + else: + cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True) + + assert len(cpu_row_idxs) <= self.cuda_row_num, \ + f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \ + f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \ + f"Please increase cuda_row_num or decrease the training batch size." + self.evict_backlist = cpu_row_idxs + tmp = torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True) + comm_cpu_row_idxs = cpu_row_idxs[tmp] + + if self._show_cache_miss: + self._cache_miss += torch.sum(repeat_times[tmp]) + self._total_cache += ids.numel() + + self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs)) + self.num_miss_history.append(len(comm_cpu_row_idxs)) + self.num_write_back_history.append(0) + + # move sure the cuda rows will not be evicted! + with record_function("(cache) prepare_rows_on_cuda"): + with self.timer("prepare_rows_on_cuda") as timer: + self._prepare_rows_on_cuda(comm_cpu_row_idxs) + + self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype) + + with self.timer("6_update_cache") as timer: + with record_function("6_update_cache"): + gpu_row_idxs = self._id_to_cached_cuda_id(ids) + + # update for LFU. + if self._evict_strategy == EvictionStrategy.LFU: + unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs] + self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times) + + return gpu_row_idxs + + def _row_in_cuda(self, row_id: int) -> bool: + return self.inverted_cached_idx[row_id] != -1 + + @torch.no_grad() + def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: + """prepare rows in cpu_row_idxs on CUDA memory + Args: + cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA + """ + evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num + + cpu_row_idxs_copy = cpu_row_idxs.cpu() + + # move evict in rows to gpu + if self._async_copy: + if self.buffer_size == 0: + evict_in_rows_gpu = self.weight.view(self.num_embeddings, + -1).index_select(0, cpu_row_idxs_copy).pin_memory() + with torch.cuda.stream(self._memcpy_stream): + evict_in_rows_gpu = evict_in_rows_gpu.to(torch.cuda.current_device(), non_blocking=True) + else: + raise NotImplemented + + if evict_num > 0: + with self.timer("2_identify_cuda_row_idxs") as timer: + mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist) + invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1) + if self._evict_strategy == EvictionStrategy.DATASET: + # mask method. + # set cached_idx_map[invalid_idxs] to -2. + # so those idxs will be sorted to end, therefore not being chosen as victim + backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone() + self.cached_idx_map.index_fill_(0, invalid_idxs, -2) + + with self.timer("2_1_find_evict_gpu_idxs") as timer: + evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) + + # move evict out rows to cpu + if self._async_copy: + evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, + -1).index_select(0, evict_gpu_row_idxs) + evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) + with torch.cuda.stream(None): + evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) + self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) + + elif self._evict_strategy == EvictionStrategy.LFU: + with self.timer("2_1_backup_freqs") as timer: + backup_freqs = self.freq_cnter[invalid_idxs].clone() + self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize) + + with self.timer("2_2_find_evict_gpu_idxs") as timer: + evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) + + if self._async_copy: + evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, + -1).index_select(0, evict_gpu_row_idxs) + evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) + with torch.cuda.stream(None): + evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) + + with self.timer("2_3_revert_freqs") as timer: + self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs) + + evict_info = self.cached_idx_map[evict_gpu_row_idxs] + + with self.timer("3_evict_out") as timer: + if self.buffer_size > 0: + self.limit_buff_index_copyer.index_copy(0, + src_index=evict_gpu_row_idxs, + tgt_index=evict_info.cpu(), + src=self.cuda_cached_weight.view(self.cuda_row_num, -1), + tgt=self.weight.view(self.num_embeddings, -1)) + else: + # allocate tmp memory on CPU and copy rows on CUDA to CPU. + # TODO async gpu -> cpu + if self._async_copy: + _wait_for_data(evict_out_rows_cpu, None) + else: + with self.timer("3_1_evict_out_index_select") as timer: + evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num, + -1).index_select(0, evict_gpu_row_idxs) + with self.timer("3_2_evict_out_gpu_to_cpu_copy") as timer: + evict_out_rows_cpu = evict_out_rows_cpu.cpu() + + with self.timer("3_2_evict_out_cpu_copy") as timer: + self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu) + + self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1) + self.inverted_cached_idx.index_fill_(0, evict_info, -1) + # self.freq_cnter.index_fill(0, evict_gpu_row_idxs, sys.maxsize) # unnecessary + self._cuda_available_row_num += evict_num + + weight_size = evict_gpu_row_idxs.numel() * self.embedding_dim + self._cuda_to_cpu_numel += weight_size + # print(f"evict embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB") + + # slots of cuda weight to evict in + with self.timer("4_identify_cuda_slot") as timer: + slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()] + + # TODO wait for optimize + with self.timer("5_evict_in") as timer: + # Here also allocate extra memory on CUDA. #cpu_row_idxs + if self.buffer_size > 0: + self.limit_buff_index_copyer.index_copy(0, + src_index=cpu_row_idxs_copy, + tgt_index=slots, + src=self.weight.view(self.num_embeddings, -1), + tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) + else: + if self._async_copy: + _wait_for_data(evict_in_rows_gpu, self._memcpy_stream) + else: + with self.timer("5_1_evict_in_index_select") as timer: + # narrow index select to a subset of self.weight + # tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1) + # evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu()) + evict_in_rows_gpu = self.weight.view(self.num_embeddings, + -1).index_select(0, cpu_row_idxs_copy).pin_memory() + + with self.timer("5_2_evict_in_gpu_to_cpu_copy") as timer: + evict_in_rows_gpu = evict_in_rows_gpu.cuda() + + with self.timer("5_3_evict_in_index_copy") as timer: + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu) + + with self.timer("6_update_cache") as timer: + self.cached_idx_map[slots] = cpu_row_idxs + self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slots) + if self._evict_strategy == EvictionStrategy.LFU: + self.freq_cnter.index_fill_(0, slots, 0) + self._cuda_available_row_num -= cpu_row_idxs.numel() + + weight_size = cpu_row_idxs.numel() * self.embedding_dim + self._cpu_to_cuda_numel += weight_size + # print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB") + + def _find_free_cuda_row(self) -> int: + if self._cuda_available_row_num == 0: + return -1 + candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1) + return candidates[0].item() + + def _evict(self) -> int: + """ + deprecated + evict one row from cuda to cpu. + Returns: + (int) : the slot id be evicted. + """ + mask = torch.logical_or(torch.isin(self.cached_idx_map, self.evict_backlist), self.cached_idx_map == -1) + buf = self.cached_idx_map[mask].clone() + idx = torch.nonzero(mask).squeeze(1) + self.cached_idx_map.index_fill_(0, idx, -1) + max_row, max_cpu_row_idx = torch.max(self.cached_idx_map, dim=0) + max_gpu_row_idx = self.cached_idx_map[max_cpu_row_idx] + + if max_gpu_row_idx == -1: + raise RuntimeError("Can not evict a row") + + max_gpu_row_idx = max_gpu_row_idx.item() + max_offset = self.inverted_cached_idx[max_gpu_row_idx] + # recover + self.cached_idx_map.index_copy_(0, idx, buf) + + with Timer() as timer: + cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim, + self.embedding_dim).view(1, self.embedding_dim) + self.cpu_weight_data(max_gpu_row_idx).data.copy_(cuda_tensor) + + # update inverted_cached_idx, min_slot_id is evicted from cuda + self.cached_idx_map[max_cpu_row_idx] = -1 + if self._evict_strategy == EvictionStrategy.LFU: + self.freq_cnter[max_cpu_row_idx] = sys.maxsize + self.inverted_cached_idx[max_gpu_row_idx] = -1 + + self._cuda_available_row_num += 1 + + self._cuda_to_cpu_numel += self.embedding_dim + # self.num_write_back_history[-1] += 1 + return max_cpu_row_idx + + @torch.no_grad() + def _admit(self, row_id: int): + """ + deprecated + move in row_id to CUDA + Args: + row_id (int): the id of row to be moved in + """ + # find a free slot in partial cuda weight + slot_id = self._find_free_cuda_row() + + if slot_id == -1: + # evict one row + slot_id = self._evict() + slot_offset = slot_id + # copy payload from cpu to cuda + with Timer() as timer: + cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim, + self.embedding_dim).view(1, self.embedding_dim) + cuda_tensor.data.copy_(self.cpu_weight_data(row_id)) + + # update the inverted_cached_idx + self.cached_idx_map[slot_id] = row_id + if self._evict_strategy == EvictionStrategy.LFU: + self.freq_cnter[slot_id] = 0 + self.inverted_cached_idx[row_id] = slot_offset + + self._cuda_available_row_num -= 1 + + self._cpu_to_cuda_numel += self.embedding_dim diff --git a/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c45d8e80c028637a5c964b41062846b7020c7b --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py @@ -0,0 +1,157 @@ +import torch +import torch.nn.functional as F +from typing import List, Optional, Iterator, Tuple, Union + +from .base_embedding import BaseEmbeddingBag +from .cache_mgr import CachedParamMgr, EvictionStrategy +from torch.nn.parameter import Parameter + + +class CachedEmbeddingBag(BaseEmbeddingBag): + """CachedEmbeddingBag + + Cached Embedding. Apply a GPU-based software cache approaches to dynamically manage the embedding table in the CPU and GPU memory space. + It can leverage the id's frequency statistics of the target dataset, by passing a frequency list to param `ids_freq_mapping`. + You can also apply a navie LFU cache eviction strategy by setting `evict_strategy` as EvictionStrategy.LFU. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; therefore, the embedding vector at padding_idx is not updated during training, i.e. it remains as a fixed โ€œpadโ€. For a newly constructed EmbeddingBag, the embedding vector at padding_idx will default to all zeros, but can be updated to another value to be used as the padding vector. Note that the embedding vector at padding_idx is excluded from the reduction. + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm + norm_type (str, optional): The p of the p-norm to compute for the max_norm option. Defaults to 2.. + scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default False. Note: this option is not supported when mode="max". Defaults to False. + sparse (bool, optional): if True, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for more details regarding sparse gradients. Note: this option is not supported when mode="max".. Defaults to False. + _weight (torch.Tensor, optional): an embedding weight tensor. Concate multiple tables in a embedding bag as a single one. Defaults to None. + mode (str, optional): "sum", "mean" or "max". Specifies the way to reduce the bag. "sum" computes the weighted sum, taking per_sample_weights into consideration. "mean" computes the average of the values in the bag, "max" computes the max value over each bag. Default: "mean". Defaults to 'mean'. + include_last_offset (bool, optional): if True, offsets has one additional element, where the last element is equivalent to the size of indices. This matches the CSR format.. Defaults to False. + dtype (torch.dtype, optional): data type of the cpu weight initialization. Defaults to None meaning float32. + device (torch.device, optional): device type to the cpu weight. Defaults to None meaning cpu. + cache_ratio (float, float): cache ratio of the #cuda_weight_row / #cpu_weight_row + ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency of each embedding vector occures in dataset. Defaults to None. + warmup_ratio (float, optional): the ratio of cuda cache is warmuped with. Defaults to 0.7. + buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, the buffer is not used. Defaults to 0. + pin_weight (bool, optional): pin the cpu weight. Defaults to False. + evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + max_norm: float = None, + norm_type: float = 2., + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[torch.Tensor] = None, + mode: str = 'mean', + include_last_offset: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + cache_ratio: float = 0.01, + ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None, + warmup_ratio: float = 0.7, + buffer_size: int = 0, + pin_weight: bool = False, + evict_strategy: EvictionStrategy = EvictionStrategy.LFU): + super(CachedEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, + scale_grad_by_freq, sparse, mode, include_last_offset) + + assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0" + self.evict_strategy = evict_strategy + if _weight is None: + _weight = self._weight_alloc(dtype, device) + cuda_row_num = int(num_embeddings * cache_ratio) + # configure weight & cache + self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight) + self.cache_op = True + + def set_cache_mgr_async_copy(self, flag): + self.cache_weight_mgr._async_copy = flag + + def _weight_alloc(self, dtype, device): + weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device) + with torch.no_grad(): + weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings) + if self.padding_idx is not None: + weight[self.padding_idx].fill_(0) + return weight + + def _preprocess(self, + weight, + cuda_row_num: int, + ids_freq_mapping: Optional[List[int]] = None, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False): + """ + Called after initialized. + Reorder the weight rows according to the ids_freq_mapping. + Then, let the weights of the Module be managed by a CachedParamMgr. + + Args: + cuda_row_num (int): number of rows can be hosted in CUDA memory + ids_freq_mapping (List[int]): a list, idx is id number, value is freq + warmup_ratio (float): the amount of rows preloaded in cuda cache + """ + self.cache_weight_mgr = CachedParamMgr(weight, + cuda_row_num, + buffer_size, + pin_weight, + evict_strategy=self.evict_strategy) + self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio) + + def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None): + if self.cache_op: + with torch.no_grad(): + input = self.cache_weight_mgr.prepare_ids(input) + + embeddings = F.embedding_bag(input.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, + per_sample_weights, self.include_last_offset, self.padding_idx) + if shape_hook is not None: + embeddings = shape_hook(embeddings) + return embeddings + + @property + def weight(self): + return self.cache_weight_mgr.weight + + def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: + yield 'weight', self.cache_weight_mgr.cuda_cached_weight + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + yield self.cache_weight_mgr.cuda_cached_weight + + def set_cache_op(self, cache_op: bool = True): + self.cache_op = cache_op + + +############################# Perf Log ################################### + + @property + def num_hits_history(self): + return self.cache_weight_mgr.num_hits_history + + @property + def num_miss_history(self): + return self.cache_weight_mgr.num_miss_history + + @property + def num_write_back_history(self): + return self.cache_weight_mgr.num_write_back_history + + @property + def swap_in_bandwidth(self): + if self.cache_weight_mgr._cpu_to_cuda_numel > 0: + return self.cache_weight_mgr._cpu_to_cuda_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \ + self.cache_weight_mgr._cpu_to_cuda_elpase + else: + return 0 + + @property + def swap_out_bandwidth(self): + if self.cache_weight_mgr._cuda_to_cpu_numel > 0: + return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \ + self.cache_weight_mgr._cuda_to_cpu_elapse + return 0 diff --git a/colossalai/nn/parallel/layers/cache_embedding/copyer.py b/colossalai/nn/parallel/layers/cache_embedding/copyer.py new file mode 100644 index 0000000000000000000000000000000000000000..b586be1dc6d98ed7df2cffd1c326ca25ab837f33 --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/copyer.py @@ -0,0 +1,49 @@ +import torch +from torch import LongTensor + + +class LimitBuffIndexCopyer(object): + """LimitBuffIndexCopyer + Index Copy using limited temp buffer on CUDA. + + Args: + size (int): buffer size + """ + + def __init__(self, size: int) -> None: + self._buff_size = size + + @torch.no_grad() + def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor): + """copy + src tensor[src_index] -(index_select)-> tmp -(index_copy_)-> tgt tensor [tgt_index] + The valid rows in the src tensor are continous, while rows in tgt tensor is scattered. + + Args: + dim (int): dimension along which to index + src_index (int): indices of src tensor to select from + tgt_index (int): indices of tgt tensor to select from + src (torch.Tensor): the tensor containing values to copy + tgt (torch.Tensor): the tensor to be copied + """ + # tgt.index_copy_(dim, index, src) + assert dim == 0, "only support index_copy on dim 0" + assert tgt.dim() == 2 + assert src.dim() == 2 + tgt_device = tgt.device + src_device = src.device + + assert src_index.numel() == tgt_index.numel() + dim_size = src_index.numel() + src_index = src_index.to(src_device) + for begin_pos in range(0, dim_size, self._buff_size): + cur_len = min(self._buff_size, dim_size - begin_pos) + src_idx_piece = src_index.narrow(0, begin_pos, cur_len) + if src_device.type == 'cpu' and tgt_device.type == 'cuda': + cpu_tmp_buffer = src.index_select(dim, src_idx_piece).pin_memory() + tmp_buffer = torch.empty_like(cpu_tmp_buffer, device=tgt_device) + tmp_buffer.copy_(cpu_tmp_buffer) + else: + tmp_buffer = src.index_select(dim, src_idx_piece).to(tgt_device) + tgt_idx_piece = tgt_index.narrow(0, begin_pos, cur_len) + tgt.index_copy_(dim, tgt_idx_piece, tmp_buffer) diff --git a/colossalai/nn/parallel/layers/cache_embedding/embedding_config.py b/colossalai/nn/parallel/layers/cache_embedding/embedding_config.py new file mode 100644 index 0000000000000000000000000000000000000000..36e04c833feb4203d9033a15951b580207890e8b --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/embedding_config.py @@ -0,0 +1,27 @@ +import torch + + +class TablewiseEmbeddingBagConfig: + ''' + example: + def prepare_tablewise_config(args, cache_ratio, ...): + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [] + ... + return embedding_bag_config_list + ''' + + def __init__(self, + num_embeddings: int, + cuda_row_num: int, + assigned_rank: int = 0, + buffer_size=50_000, + ids_freq_mapping=None, + initial_weight: torch.tensor = None, + name: str = ""): + self.num_embeddings = num_embeddings + self.cuda_row_num = cuda_row_num + self.assigned_rank = assigned_rank + self.buffer_size = buffer_size + self.ids_freq_mapping = ids_freq_mapping + self.initial_weight = initial_weight + self.name = name diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..d7f77e195f4b480da53e6f1772145d122afb1e69 --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py @@ -0,0 +1,141 @@ +import torch +import torch.nn.functional as F +from typing import List, Optional, Iterator, Tuple + +from .cached_embedding import CachedEmbeddingBag +from colossalai.nn._ops._utils import dual_all_to_all + +from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor +from .cache_mgr import CachedParamMgr, EvictionStrategy + + +def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]: + if world_size == 1: + return 0, embedding_dim, True + + assert embedding_dim >= world_size, \ + f"Embedding dimension {embedding_dim} must be larger than the world size " \ + f"{world_size} of the process group" + chunk_size = embedding_dim // world_size + threshold = embedding_dim % world_size + # if embedding dim is divisible by world size + if threshold == 0: + return rank * chunk_size, (rank + 1) * chunk_size, True + + # align with the split strategy of torch.tensor_split + size_list = [chunk_size + 1 if i < threshold else chunk_size for i in range(world_size)] + offset = sum(size_list[:rank]) + return offset, offset + size_list[rank], False + + +class ParallelCachedEmbeddingBag(CachedEmbeddingBag): + + def __init__(self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2., + scale_grad_by_freq=False, + sparse=False, + _weight=None, + mode='mean', + include_last_offset=False, + dtype=None, + device=None, + cache_ratio=0.01, + ids_freq_mapping=None, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.DATASET): + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + self.partition_start_index, self.partition_end_index, divisible = get_partition( + embedding_dim, self.rank, self.world_size) + self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index + + super(ParallelCachedEmbeddingBag, + self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, + sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping, + warmup_ratio, buffer_size, pin_weight, evict_strategy) + self.cache_op = True + + def _weight_alloc(self, dtype, device): + weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype) + with torch.no_grad(): + weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings) + if self.padding_idx is not None: + weight[self.padding_idx].fill_(0) + colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size), + dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]), + compute_attr=ComputePattern.TP1D) + return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec) + + def forward( + self, + indices, + offsets=None, + per_sample_weights=None, + shape_hook=None, + scatter_dim=0, + gather_dim=-1, + ): + if self.cache_op: + with torch.no_grad(): + indices = self.cache_weight_mgr.prepare_ids(indices) + output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, + per_sample_weights, self.include_last_offset, self.padding_idx) + if shape_hook is not None: + output_shard = shape_hook(output_shard) + output_full = dual_all_to_all(output_shard, + self.weight.get_process_group(), + scatter_dim=scatter_dim, + gather_dim=gather_dim) + return output_full + + def set_cache_op(self, cache_op: bool = True): + self.cache_op = cache_op + + @classmethod + def from_pretrained( + cls, + embedding: torch.Tensor, + freeze: bool = True, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2., + scale_grad_by_freq: bool = False, + sparse: bool = False, + mode: str = 'mean', + include_last_offset: bool = False, + cuda_row_num: int = 100_000, + ids_freq_mapping: Optional[List[int]] = None, + warmup_ratio: float = 0.7, + buffer_size: int = 0, + ) -> 'ParallelCachedEmbeddingBag': + rows, cols = embedding.shape + embedding_bag = cls(rows, + cols, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + embedding, + mode, + include_last_offset, + cuda_row_num=cuda_row_num, + ids_freq_mapping=ids_freq_mapping, + warmup_ratio=warmup_ratio, + buffer_size=buffer_size) + embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_ = not freeze + return embedding_bag + + def print_comm_stats_(self): + self.cache_weight_mgr.print_comm_stats() + + def element_size(self): + return self.weight.element_size() diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py new file mode 100644 index 0000000000000000000000000000000000000000..949f85ad4baf894d4e53f06594bcbd9080f249ae --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py @@ -0,0 +1,198 @@ +import torch +import torch.distributed as dist +import torch.nn.functional as F + +from .cached_embedding import CachedEmbeddingBag +from .cache_mgr import EvictionStrategy +from .embedding_config import TablewiseEmbeddingBagConfig +from colossalai.tensor import ProcessGroup +from colossalai.nn._ops._utils import dual_all_to_all_tablewise + +from typing import List +import time + + +class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag): + """ + all tables assigned to this class instance are managed by a single CachedEmbeddingBag. + Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight. + """ + + def __init__(self, + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], + embedding_dim: int, + padding_idx=None, + max_norm=None, + norm_type=2., + scale_grad_by_freq=False, + sparse=False, + _weight=None, + mode='mean', + include_last_offset=False, + dtype=None, + device=None, + cache_ratio=0.01, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.LFU): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list] + self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list] + self.global_tables_num = len(embedding_bag_config_list) + self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda() + self.assigned_table_list: List[int] = [] + self.pg = ProcessGroup(tp_degree=self.world_size) + self.num_embeddings = 0 + for i, rank in enumerate(self.rank_of_tables): + if rank == self.rank: + self.assigned_table_list.append(i) + self.num_embeddings += self.global_table_num_embeddings_list[i] + self.include_last_offset = include_last_offset + + ids_freq_mapping = [] + for config in embedding_bag_config_list: + if config.assigned_rank == self.rank: + if config.ids_freq_mapping != None: + ids_freq_mapping.extend(config.ids_freq_mapping) + else: + ids_freq_mapping = None + break + self.cache_ratio = cache_ratio + # table-associate cache + cuda_row_num = int(cache_ratio * self.num_embeddings) + super(ParallelCachedEmbeddingBagTablewise, + self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, + sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping, + warmup_ratio, buffer_size, pin_weight, evict_strategy) + + # for assigned tables reconnection: + self.idx_offset_list = [] + offset_cumsum = 0 + for table_i, table_num_embeddings in enumerate(self.global_table_num_embeddings_list): + if self.rank_of_tables[table_i] == self.rank: + self.idx_offset_list.append(offset_cumsum) + else: + offset_cumsum += table_num_embeddings + + # prepare list shape for all_to_all output + self.embedding_dim_per_rank = [0 for i in range(self.world_size)] + for rank in self.rank_of_tables: + self.embedding_dim_per_rank[rank] += embedding_dim + + self.cache_op = True + + def forward( + self, + indices: torch.Tensor, + offsets: torch.Tensor = None, + per_sample_weights=None, + shape_hook=None, + already_split_along_rank=True, + ): + if not already_split_along_rank: + # not recommanded. it takes time. + batch_size = (offsets.shape[0]) // self.global_tables_num + local_indices, local_offsets, local_per_sample_weights = self.split_along_rank( + batch_size, indices, offsets, per_sample_weights) + else: + # recommanded. + batch_size = (offsets.shape[0]) // len(self.assigned_table_list) + local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights + if self.cache_op: + with torch.no_grad(): + indices = self.cache_weight_mgr.prepare_ids(local_indices) + local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets, + self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, + local_per_sample_weights, self.include_last_offset, self.padding_idx) + local_output = torch.cat(local_output.split(batch_size), 1) + remains = batch_size % self.world_size + scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)] + output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank) + if shape_hook is not None: + output_full = shape_hook(output_full) + return output_full + + def split_along_rank(self, + batch_size, + indices: torch.Tensor, + offsets: torch.Tensor = None, + per_sample_weights=None): + ''' + if input indices and offsets haven't been splitted along assigned rank, this function will do it. + it takes time. please consider splitting data during batch loading. + ''' + local_indices_list: List(torch.Tensor) = [] + local_offsets_list: List(torch.Tensor) = [] + if per_sample_weights != None: + local_per_sample_weights_list: List(torch.Tensor) = [] + + offset_pre_end = 0 # local_offsets trick + for i, handle_table in enumerate(self.assigned_table_list): + indices_start_position = offsets[batch_size * handle_table] + if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]): + # till-the-end special case + indices_end_position = indices.shape[0] + else: + indices_end_position = offsets[batch_size * (handle_table + 1)] + # alternative approach: reduce malloc + ''' + # 1. local_indices_list: + local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position) + torch.sub(local_indices, self.idx_offset_list[i], out=local_indices) + local_indices_list.append(local_indices) + # 2. local_offsets_list: + if i + 1 == len(self.assigned_table_list): + # till-the-end special case + if not self.include_last_offset: + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size) + else: + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1) + torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets) + local_offsets_list.append(local_offsets) + else: + temp_holder = offsets[batch_size * handle_table].item() + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size) + torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets) + offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder + local_offsets_list.append(local_offsets) + ''' + # 1. local_indices_list: + local_indices_list.append( + indices.narrow(0, indices_start_position, + indices_end_position - indices_start_position).sub(self.idx_offset_list[i])) + # 2. local_offsets_list: + if i + 1 == len(self.assigned_table_list): + # till-the-end special case + if not self.include_last_offset: + local_offsets = offsets.narrow(0, batch_size * handle_table, + batch_size).add(offset_pre_end - offsets[batch_size * + (handle_table)]) + else: + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + + 1).add(offset_pre_end - offsets[batch_size * (handle_table)]) + local_offsets_list.append(local_offsets) + else: + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + + 1).add(offset_pre_end - offsets[batch_size * (handle_table)]) + offset_pre_end = local_offsets[-1] + local_offsets_list.append(local_offsets[:-1]) + # 3. local_per_sample_weights_list: + if per_sample_weights != None: + local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position]) + local_indices = torch.cat(local_indices_list, 0) + local_offsets = torch.cat(local_offsets_list, 0) + local_per_sample_weights = None + if per_sample_weights != None: + local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0) + return local_indices, local_offsets, local_per_sample_weights + + def set_cache_op(self, cache_op: bool = True): + self.cache_op = cache_op + + def print_comm_stats_(self): + self.cache_weight_mgr.print_comm_stats() + + def element_size(self): + return self.weight.element_size() diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..cb4647028d477d6c89ddfcb0d68d906ba43aa4c8 --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py @@ -0,0 +1,138 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.profiler import record_function + +from .cached_embedding import CachedEmbeddingBag + +from colossalai.tensor import ProcessGroup +from colossalai.nn._ops._utils import dual_all_to_all_tablewise +from .embedding_config import TablewiseEmbeddingBagConfig +from .cache_mgr import EvictionStrategy + +from typing import List +import abc + + +class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module): + """ + every table assigned to this class instance is managed by a CachedEmbeddingBag. + """ + + def __init__(self, + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], + embedding_dim: int, + padding_idx=None, + max_norm=None, + norm_type=2., + scale_grad_by_freq=False, + sparse=False, + mode='mean', + include_last_offset=False, + dtype=None, + device=None, + warmup_ratio=0.7, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.LFU): + super(ParallelCachedEmbeddingBagTablewiseSpiltCache, self).__init__() + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list] + self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list] + self.global_tables_num = len(embedding_bag_config_list) + self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda() + + self.assigned_table_list: List[int] = [] + for i, rank in enumerate(self.rank_of_tables): + if rank == self.rank: + self.assigned_table_list.append(i) + self.include_last_offset = include_last_offset + self.pg = ProcessGroup(tp_degree=self.world_size) + + # prepare CachedEmbeddingBag list + + self.cached_embedding_bag_list: nn.ModuleList = nn.ModuleList() + for config in embedding_bag_config_list: + if config.assigned_rank != self.rank: + continue + self.cached_embedding_bag_list.append( + CachedEmbeddingBag(num_embeddings=config.num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + _weight=config.initial_weight, + mode=mode, + include_last_offset=include_last_offset, + dtype=dtype, + device=device, + cuda_row_num=config.cuda_row_num, + ids_freq_mapping=config.ids_freq_mapping, + warmup_ratio=warmup_ratio, + buffer_size=config.buffer_size, + pin_weight=pin_weight, + evict_strategy=evict_strategy)) + + # prepare list shape for all_to_all output + self.embedding_dim_per_rank = [0 for i in range(self.world_size)] + for rank in self.rank_of_tables: + self.embedding_dim_per_rank[rank] += embedding_dim + + def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None): + # determine indices to handle + batch_size = (offsets.shape[0]) // self.global_tables_num + local_output_list = [] + for i, handle_table in enumerate(self.assigned_table_list): + with record_function("(tablewise) prepare indices and offsets"): + with record_function("part 1"): + indices_start_position = offsets[batch_size * handle_table] + if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]): + # till the end special case + indices_end_position = indices.shape[0] + else: + indices_end_position = offsets[batch_size * (handle_table + 1)] + with record_function("part 2"): + # local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table] + local_indices = indices.narrow(0, indices_start_position, indices_end_position - + indices_start_position).sub(self.global_tables_offsets[handle_table]) + if self.include_last_offset: + # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)] + local_offsets = offsets.narrow(0, batch_size * handle_table, + batch_size + 1).sub(offsets[batch_size * (handle_table)]) + else: + # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)] + local_offsets = offsets.narrow(0, batch_size * handle_table, + batch_size).sub(offsets[batch_size * (handle_table)]) + local_per_sample_weights = None + if per_sample_weights != None: + local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position] + with record_function("(tablewise) tablewise forward"): + local_output_list.append(self.cached_embedding_bag_list[i](local_indices, local_offsets, + local_per_sample_weights)) + + # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim)) + local_output = torch.cat(local_output_list, 1) + # then concatenate those local_output on the second demension. + # use all_to_all + remains = batch_size % self.world_size + scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)] + output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank) + if shape_hook is not None: + output_full = shape_hook(output_full) + return output_full + + def element_size(self): + if len(self.assigned_table_list) == 0: + return 0 + return self.cached_embedding_bag_list[0].cache_weight_mgr.weight.element_size() + + def print_comm_stats_(self): + cuda_to_cpu_elem_num = 0 + cpu_to_cuda_elem_num = 0 + for cached_embedding_bag in self.cached_embedding_bag_list: + cuda_to_cpu_elem_num += cached_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel + cpu_to_cuda_elem_num += cached_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel + print(f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem") + print(f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem") diff --git a/colossalai/nn/parallel/layers/colo_module.py b/colossalai/nn/parallel/layers/colo_module.py new file mode 100644 index 0000000000000000000000000000000000000000..8f0f5d5f520a17c979a21048903b025dc642296f --- /dev/null +++ b/colossalai/nn/parallel/layers/colo_module.py @@ -0,0 +1,46 @@ +from colossalai.tensor.distspec import _DistSpec +from colossalai.tensor import ComputePattern +from typing import List, Dict + + +class ColoModule(object): + + def __init__(self): + self._shard_params: List[str] = [] + self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {} + + def _register_shard_params(self, params: List[str]): + self._shard_params = params + + def _register_allowed_patterns(self, + compute_pattern: ComputePattern, + dist_specs: Dict[str, _DistSpec], + mode='default'): + assert list( + dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.' + if not compute_pattern in self._allowed_patterns: + self._allowed_patterns[compute_pattern] = {} + self._allowed_patterns[compute_pattern][mode] = dist_specs + + def _set_default(self, compute_pattern: ComputePattern, target_mode): + self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_mode] + + def has_compute_pattern(self, compute_pattern: ComputePattern): + return compute_pattern in self._allowed_patterns + + def get_dist_specs(self, compute_pattern: ComputePattern): + assert self.has_compute_pattern(compute_pattern) + return self._allowed_patterns[compute_pattern] + + def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode='default'): + return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern] + + def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode='default'): + assert self.has_compute_pattern_with_mode(compute_pattern, mode) + return self._allowed_patterns[compute_pattern][mode] + + def get_param_names(self): + return self._shard_params + + def register(self, compute_pattern, pg): + raise NotImplementedError diff --git a/colossalai/nn/parallel/layers/embedding.py b/colossalai/nn/parallel/layers/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..ccacc1ead297b349c7252aec37a41f106dff7993 --- /dev/null +++ b/colossalai/nn/parallel/layers/embedding.py @@ -0,0 +1,36 @@ +from .colo_module import ColoModule +from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec + + +class ColoEmbedding(ColoModule): + + def __init__(self): + super(ColoEmbedding, self).__init__() + self._register_shard_params(['weight']) + + def register(self, compute_pattern, pg: ProcessGroup): + if not compute_pattern in self._allowed_patterns: + if ComputePattern.TP1D == compute_pattern: + self._set_TP1D(pg) + + def _set_TP1D(self, pg: ProcessGroup): + # TP1D Row Linear + _compute_pattern = ComputePattern.TP1D + self._register_allowed_patterns( + compute_pattern=_compute_pattern, + dist_specs={ + 'weight': ShardSpec([0], [pg.tp_world_size()]), + }, + mode='row', + ) + + # TP1D Col Linear + self._register_allowed_patterns( + compute_pattern=_compute_pattern, + dist_specs={ + 'weight': ShardSpec([-1], [pg.tp_world_size()]), + }, + mode='col', + ) + + self._set_default(compute_pattern=_compute_pattern, target_mode='row') diff --git a/colossalai/nn/parallel/layers/linear.py b/colossalai/nn/parallel/layers/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..84a8c042587dfdb279f4bd1f83e07d317bc57d08 --- /dev/null +++ b/colossalai/nn/parallel/layers/linear.py @@ -0,0 +1,38 @@ +from .colo_module import ColoModule +from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec + + +class ColoLinear(ColoModule): + + def __init__(self): + super(ColoLinear, self).__init__() + self._register_shard_params(['weight', 'bias']) + + def register(self, compute_pattern, pg: ProcessGroup): + if not compute_pattern in self._allowed_patterns: + if ComputePattern.TP1D == compute_pattern: + self._set_TP1D(pg) + + def _set_TP1D(self, pg): + # TP1D Row Linear + _compute_pattern = ComputePattern.TP1D + self._register_allowed_patterns( + compute_pattern=_compute_pattern, + dist_specs={ + 'weight': ShardSpec([-1], [pg.tp_world_size()]), + 'bias': None + }, + mode='row', + ) + + # TP1D Col Linear + self._register_allowed_patterns( + compute_pattern=_compute_pattern, + dist_specs={ + 'weight': ShardSpec([0], [pg.tp_world_size()]), + 'bias': ShardSpec([0], [pg.tp_world_size()]) + }, + mode='col', + ) + + self._set_default(compute_pattern=_compute_pattern, target_mode='row') diff --git a/colossalai/nn/parallel/layers/module_utils.py b/colossalai/nn/parallel/layers/module_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..38d128cc705e6bfa3db68cb6c59b66450ce1223e --- /dev/null +++ b/colossalai/nn/parallel/layers/module_utils.py @@ -0,0 +1,113 @@ +from typing import Dict +from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup +from colossalai.tensor import distspec +from . import ColoModule +import torch + +_COLOSSAL_MODULES: Dict[type, ColoModule] = {} + + +def register_colo_module(module_type: type, colo_module: ColoModule): + global _COLOSSAL_MODULES + _COLOSSAL_MODULES[module_type] = colo_module + + +def is_colo_module(module: torch.nn.Module): + global _COLOSSAL_MODULES + for module_type in _COLOSSAL_MODULES.keys(): + if isinstance(module, module_type): + return True + return False + + +def get_colo_module(module: torch.nn.Module): + global _COLOSSAL_MODULES + if is_colo_module(module): + for module_type, colo_module in _COLOSSAL_MODULES.items(): + if isinstance(module, module_type): + return colo_module + else: + return None + + +def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True): + if is_colo_module(module): + colo_module = get_colo_module(module) + param_names = colo_module.get_param_names() + compute_pattern = None + for param_name in param_names: + param = module.get_parameter(param_name) + if not isinstance(param, ColoParameter): + raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.') + if param.has_compute_spec(): + cur_compute_pattern = param.compute_spec.compute_pattern + if compute_pattern is None: + compute_pattern = cur_compute_pattern + else: + if cur_compute_pattern != compute_pattern: + raise Exception( + f'Invalid ColoParameter spec: Params in {module} have different compute_pattern.') + else: + continue + + if compute_pattern is not None: + colo_module.register(compute_pattern, pg) + if not colo_module.has_compute_pattern(compute_pattern): + raise Exception( + f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.') + + match_specs = False + allowed_specs = colo_module.get_dist_specs(compute_pattern) + for _, param_specs in allowed_specs.items(): + cur_match = True + for param_name, dist_spec in param_specs.items(): + param = module.get_parameter(param_name) + if param.has_compute_spec(): + if dist_spec != param.dist_spec: + cur_match = False + break + else: + if dist_spec is not None: + cur_match = False + break + if cur_match == True: + match_specs = True + break + if match_specs == False: + raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.') + if recursive == True: + for submodule in module.children(): + check_colo_module(submodule, pg=pg, recursive=True) + + +def init_colo_module(module: torch.nn.Module, + compute_spec: ComputeSpec, + pg: ProcessGroup, + recursive=True, + mode='default'): + compute_pattern = compute_spec.compute_pattern + if is_colo_module(module): + # for each param + # set its process_group, dist_spec and compute_spec + colo_module = get_colo_module(module) + colo_module.register(compute_pattern, pg) + if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode): + raise NotImplementedError + # a set for modules which update at least one param in the init process. + # these modules need to be checked whether all params still match one of the valid compute pattern. + modules_update_param = {module} + for param_name, dist_spec in colo_module.get_dist_specs_with_mode(compute_pattern, mode=mode).items(): + if dist_spec is None: + continue + param = module.get_parameter(param_name) + if isinstance(param, ColoParameter): + param.set_process_group(pg) + param.set_dist_spec(dist_spec) + param.compute_spec = compute_spec + for mod in param.shared_param_modules: + modules_update_param.add(mod) + for mod in modules_update_param: + check_colo_module(mod, pg, recursive=False) + if recursive == True: + for submodule in module.children(): + init_colo_module(submodule, compute_spec, pg=pg, recursive=True, mode=mode) diff --git a/colossalai/nn/parallel/reducer.py b/colossalai/nn/parallel/reducer.py new file mode 100644 index 0000000000000000000000000000000000000000..5687055819fe1fd0177e507f1a27d3bab8b5b1b5 --- /dev/null +++ b/colossalai/nn/parallel/reducer.py @@ -0,0 +1,116 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import functools +from typing import Callable, Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup + + +class Bucket: + + def __init__(self, size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup): + self.buffer = torch.zeros(size, dtype=dtype, device=device) + self.group = group + self.offset = 0 + self.callbacks: List[Callable] = [] + + def flush(self) -> None: + """Flush content of the bucket.""" + if self.offset == 0: + assert len(self.callbacks) == 0 + return + # reduce-scatter bucket + dist.all_reduce(self.buffer[:self.offset], group=self.group) + + # execute post-reduction callbacks + for callback_fn in self.callbacks: + callback_fn() + # reuse input bucket but allocate a fresh output shard + self.offset = 0 + self.callbacks.clear() + self.buffer = torch.zeros_like(self.buffer) + + def alloc(self) -> None: + + if self.buffer.storage().size() == 0: + self.buffer.storage().resize_(self.buffer.numel()) + + def free(self) -> None: + + assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown" + self.buffer.storage().resize_(0) + + def append(self, tensor: Tensor, callback_fn: Callable): + tensor_size = tensor.numel() + offset = self.offset + self.buffer[offset:offset + tensor_size].copy_(tensor.flatten()) + self.offset += tensor_size + + # callback will be given the reduced result + if callback_fn is not None: + result_view = self.buffer[offset:offset + tensor_size].view(tensor.shape) + self.callbacks.append(functools.partial(callback_fn, result_view)) + + @property + def avail_size(self) -> int: + return self.buffer.size(0) - self.offset + + +class Reducer: + + def __init__(self, bucket_size_mb: int = 25): + self.bucket_size_mb = bucket_size_mb + self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {} + + @torch.no_grad() + def all_reduce_async( + self, + tensor: Tensor, + group: ProcessGroup, + callback_fn: Optional[Callable] = None, + ) -> None: + bucket_size = self._get_bucket_size(tensor.element_size()) + + if tensor.numel() >= bucket_size: + dist.all_reduce(tensor, group=group) + if callback_fn is not None: + callback_fn(tensor) + return + + bucket = self._get_bucket(tensor, group) + if tensor.numel() > bucket.avail_size: + # not enough space remaining in bucket, flush it now + bucket.flush() + bucket.append(tensor, callback_fn) + + @torch.no_grad() + def flush(self) -> None: + for bucket in self.buckets.values(): + bucket.flush() + + @torch.no_grad() + def free(self) -> None: + for bucket in self.buckets.values(): + bucket.free() + + @functools.lru_cache() + def _get_bucket_size(self, element_size: int) -> int: + if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. + return 0 + MB = 1024 * 1024 + bucket_size = self.bucket_size_mb * MB / element_size + return int(bucket_size) + + def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket: + key = (tensor.dtype, tensor.device, group) + if key not in self.buckets: + bucket_size = self._get_bucket_size(tensor.element_size()) + self.buckets[key] = Bucket(bucket_size, tensor.dtype, tensor.device, group) + self.buckets[key].alloc() + return self.buckets[key] diff --git a/colossalai/nn/parallel/utils.py b/colossalai/nn/parallel/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..844439cded1ea99585500f9fc3dd3b9a8c25642a --- /dev/null +++ b/colossalai/nn/parallel/utils.py @@ -0,0 +1,49 @@ +import torch +import torch.distributed as dist + +from colossalai.gemini.chunk import Chunk +from colossalai.tensor import ColoTensor +from colossalai.utils import get_current_device + + +def get_temp_total_chunk_on_cuda(chunk: Chunk): + if chunk.is_gathered: + return chunk.cuda_global_chunk + + if chunk.cuda_shard is not None: + shard_temp = chunk.cuda_shard + else: + shard_temp = chunk.cpu_shard.to(get_current_device()) + + total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device()) + gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0)) + dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg) + + return total_temp + + +def _add_param(model, name, param): + name_list = name.split('.') + module = model._modules[name_list[0]] + for i in range(1, len(name_list) - 1): + module = module._modules[name_list[i]] + module._parameters[name_list[-1]] = param + + +def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module: + """convert_to_torch_module + + Args: + gemini_ddp_model (GeminiDDP): a gemini ddp model + + Returns: + torch.nn.Module: a torch model contains the params of gemini_ddp_model + """ + module = gemini_ddp_model.module + + for n, p in module.named_parameters(): + if isinstance(p, ColoTensor): + p.to_replicate_() + _add_param(module, n, p.data) + + return module diff --git a/colossalai/pipeline/__init__.py b/colossalai/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0fcde970764688c210d37de51964b5b081834666 --- /dev/null +++ b/colossalai/pipeline/__init__.py @@ -0,0 +1,4 @@ +from .pipelinable import PipelinableContext, PipelinableModel +from .layer_spec import LayerSpec + +__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec'] \ No newline at end of file diff --git a/colossalai/pipeline/layer_spec.py b/colossalai/pipeline/layer_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9169efff78bad7d30f42e9896dc3f61ecaf7fd --- /dev/null +++ b/colossalai/pipeline/layer_spec.py @@ -0,0 +1,55 @@ +import torch +from colossalai.utils.model.utils import call_to_str + +class LayerSpec: + """ + + """ + + def __init__(self, typename, *module_args, **module_kwargs): + self.typename = typename + self.module_args = module_args + self.module_kwargs = module_kwargs + self.children = None + self._param_count = 0 + + if not issubclass(typename, torch.nn.Module): + raise RuntimeError('LayerSpec only supports torch.nn.Module types.') + + def __repr__(self): + return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs) + + @property + def param_count(self): + return self._param_count + + def build(self): + """Build the stored specification.""" + + recovered_args = [] + for obj in self.module_args: + if isinstance(obj, LayerSpec): + obj = obj.build() + recovered_args.append(obj) + recovered_args = tuple(recovered_args) + + recovered_kwargs = {} + for k, v in self.module_kwargs.items(): + if isinstance(v, LayerSpec): + v = v.build() + recovered_kwargs[k] = v + + return self.typename(*recovered_args, **recovered_kwargs) + + def set_children(self, children): + self.children = children + + def count_params(self): + self._param_count = 0 + layer = self.build() + for param in layer.parameters(): + self._param_count += param.numel() + return self._param_count + + def reset_param_count(self): + self._param_count = 0 \ No newline at end of file diff --git a/colossalai/pipeline/middleware/__init__.py b/colossalai/pipeline/middleware/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79e19f9eaf771635852d8ffd747c06ea1209e110 --- /dev/null +++ b/colossalai/pipeline/middleware/__init__.py @@ -0,0 +1,3 @@ +from .topo import Topo, Partition, PartitionOutputVal, PartitionInputVal + +__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal'] \ No newline at end of file diff --git a/colossalai/pipeline/middleware/adaptor/__init__.py b/colossalai/pipeline/middleware/adaptor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..949700a2c49de505b37c408fd44283ef67f9569b --- /dev/null +++ b/colossalai/pipeline/middleware/adaptor/__init__.py @@ -0,0 +1,3 @@ +from .fx import get_topology as get_fx_topology + +__all__ = ['get_fx_topology'] \ No newline at end of file diff --git a/colossalai/pipeline/middleware/adaptor/fx.py b/colossalai/pipeline/middleware/adaptor/fx.py new file mode 100644 index 0000000000000000000000000000000000000000..8437c519476218dec90c968a47a73440ed71f519 --- /dev/null +++ b/colossalai/pipeline/middleware/adaptor/fx.py @@ -0,0 +1,145 @@ +from torch.fx.graph_module import GraphModule +from colossalai.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo +import torch + +def partition_name_to_id(partition_name, is_input=False, is_output=False): + if is_input: + partition_id = 0 + elif is_output: + partition_id = 1 + else: + prefix = 'submod_' + partition_id = int(partition_name.split(prefix)[-1]) + 2 + return partition_id + +# There are two kinds of def in fx.graph +# 1. non direct_use & non direct_def, which means the output is used by next partition with a temporary mid value. +# e.g. submod1 = call_module(...) +# temporary_val = submod1[0] +# submod2 = call_module(temporary_val, ...) +# 2. direct_use & direct_def, which means the output is used by next partition directly. +# e.g. submod1 = call_module(...) +# submod2 = call_module(submod1, ...) +def find_input_in_partition(node, partitions, input_partitions=None): + p_input_val = None + direct_def = not node.name.startswith('getitem') + # search in input + if direct_def and input_partitions is not None: + partition_id = partition_name_to_id('', is_input=True) + for i, input_node in enumerate(input_partitions): + if input_node == node: + p_input_val = PartitionInputVal(partition_id=partition_id, offset=i) + return p_input_val + # search submod in mid part + if direct_def: + for partition in partitions: + if partition == node: + partition_id = partition_name_to_id(partition.name) + p_input_val = PartitionInputVal(partition_id=partition_id, offset=0) + return p_input_val + # search temporary value in graph + else: + for partition in partitions: + for offset, mid_val in enumerate(partition.users): + if mid_val == node: + partition_id = partition_name_to_id(partition.name) + p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset) + return p_input_val + + return p_input_val + +def find_output_in_partition(node, partitions, output_partitions=None): + p_output_val = PartitionOutputVal() + for user in node.users: + direct_use = not user.name.startswith('getitem') + # user is mid partition + for partition in partitions: + # direct call + if direct_use: + if user == partition: + partition_id = partition_name_to_id(partition.name) + for i, arg in enumerate(partition.args): + if arg == node: + p_output_val.add(partition_id=partition_id, offset=i) + break + # getitem call + else: + if user in partition.args: + partition_id = partition_name_to_id(partition.name) + for i, arg in enumerate(partition.args): + if arg == user: + p_output_val.add(partition_id=partition_id, offset=i) + break + + # user is output + if output_partitions is not None: + output_node = output_partitions[0] + if user.op == output_node.op: + output_keys = {} + partition_id = partition_name_to_id('', is_output=True) + torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n)) + for i, arg in enumerate(output_keys): + if arg == node: + p_output_val.add(partition_id=partition_id, offset=i) + break + return p_output_val + +def get_topology(gm: GraphModule): + topo = Topo() + topo_output_partition = Partition() + + input_partitions = [] + partitions = [] + output_partitions = [] + for node in gm.graph.nodes: + if node.op == 'placeholder': + input_partitions.append(node) + elif node.name.startswith('submod_'): + partitions.append(node) + elif node.op == 'output': + output_partitions.append(node) + else: + continue + + # set output for input_partition + topo_input_partition = Partition() + for partition in input_partitions: + cur_node = partition + p_output_val = find_output_in_partition(cur_node, partitions, output_partitions) + topo_input_partition.add_output_val(p_output_val) + topo.set_partitions(partition_id=0, partition=topo_input_partition) + topo.set_input_partition_id(partition_id=0) + + for i, partition in enumerate(partitions): + topo_mid_partition = Partition() + # set input for submodule + for arg in partition.args: + cur_node = arg + p_input_val = find_input_in_partition(cur_node, partitions, input_partitions) + topo_mid_partition.add_input_val(p_input_val) + # set output for submodule + direct_use = True + for user in partition.users: + if user.name.startswith('getitem'): + direct_use = False + break + if direct_use: + cur_node = partition + p_output_val = find_output_in_partition(cur_node, partitions, output_partitions) + topo_mid_partition.add_output_val(p_output_val) + else: + for user in partition.users: + cur_node = user + p_output_val = find_output_in_partition(cur_node, partitions, output_partitions) + topo_mid_partition.add_output_val(p_output_val) + topo.set_partitions(partition_id=i+2, partition=topo_mid_partition) + + # set input for output_partition + for partition in output_partitions: + topo_output_partition = Partition() + torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val( + find_input_in_partition(n, partitions, input_partitions))) + topo.set_partitions(partition_id=1, partition=topo_output_partition) + topo.set_output_partition_id(partition_id=1) + + return topo \ No newline at end of file diff --git a/colossalai/pipeline/middleware/topo.py b/colossalai/pipeline/middleware/topo.py new file mode 100644 index 0000000000000000000000000000000000000000..e798e2ed9cab0cd3036b2dc39169c2f81470bae8 --- /dev/null +++ b/colossalai/pipeline/middleware/topo.py @@ -0,0 +1,206 @@ +from typing import Dict, List +from dataclasses import dataclass + +# This file includes data structure used by Pipeline Middleware. + +@dataclass +class ValPosition: + partition_id: int + offset: int + + def __str__(self) -> str: + res = f'[partition_id:{self.partition_id},offset:{self.offset}]' + return res + + def __repr__(self) -> str: + return self.__str__() + +class PartitionInputVal(object): + def __init__(self, partition_id, offset) -> None: + # every input from which partition_id and which offset + val_pos = ValPosition(partition_id, offset) + self._from_partition_and_offset: ValPosition = val_pos + + def get(self): + return self._from_partition_and_offset + + def __str__(self) -> str: + res = '' + res += f'<-({self._from_partition_and_offset})' + return res + + def __repr__(self) -> str: + return self.__str__() + +class PartitionOutputVal(object): + def __init__(self) -> None: + # every output to which partition_id and which offset + self._to_partition_and_offset: List[ValPosition] = [] + + def add(self, partition_id, offset): + val_pos = ValPosition(partition_id, offset) + self._to_partition_and_offset.append(val_pos) + + def get(self): + return self._to_partition_and_offset + + def __str__(self) -> str: + res = '' + res += '->(' + for val_pos in self._to_partition_and_offset: + res += f'{val_pos},' + res += ')' + return res + + def __repr__(self) -> str: + return self.__str__() + +class Partition(object): + def __init__(self) -> None: + self._input_vals: List[PartitionInputVal] = [] + self._output_vals: List[PartitionOutputVal] = [] + + def add_input_val(self, input_val: PartitionInputVal): + self._input_vals.append(input_val) + + def add_output_val(self, output_val: PartitionOutputVal): + self._output_vals.append(output_val) + + def get_input_vals(self): + return self._input_vals + + def get_output_vals(self): + return self._output_vals + + # get the output offsets sent to dst_partition_id + def get_output_offsets(self, dst_partition_id): + res = [] + for offset, output_val in enumerate(self._output_vals): + outputs = output_val.get() + for val_pos in outputs: + if val_pos.partition_id == dst_partition_id: + res.append(offset) + + return res + + # get all input dst partition_ids + def get_input_partition_ids(self): + res = [] + for input_val in self._input_vals: + val_pos = input_val.get() + if val_pos.partition_id not in res: + res.append(val_pos.partition_id) + return res + + # get all output dst partition_ids + def get_output_partition_ids(self): + res = [] + for output_val in self._output_vals: + outputs = output_val.get() + for val_pos in outputs: + if val_pos.partition_id not in res: + res.append(val_pos.partition_id) + return res + + def __str__(self) -> str: + res = '' + res += f' input:\n' + res += f' length:{len(self._input_vals)}\n' + for i, input_val in enumerate(self._input_vals): + res += f' offset={i}:{input_val}\n' + + res += f' output:\n' + res += f' length:{len(self._output_vals)}\n' + for i, output_val in enumerate(self._output_vals): + res += f' offset={i}:{output_val}\n' + + return res + + def __repr__(self) -> str: + return self.__str__() + +# This class is a middleware between partition splitter +# and Pipeline Scheduler. It records the graph info about +# partition input/output and provides it to scheduler. +# There are three kinds of partition in Pipeline Middleware Design +# which represents the whole process of a model execution: input-fwd-output +# 1. input_partition: records the input of a model. +# 2. mid_partition: record the splitted forwards execution of a model. +# 3. output_partition: records the output of a model. +# attributes: +# _partitions: include all partitions +# _input_partition_id: the key represents input_partition +# _output_partition_id: the key represents output_partition +class Topo(object): + def __init__(self, input_partition_id=None, output_partition_id=None) -> None: + self._partitions: Dict[int, Partition] = {} + self._input_partition_id = input_partition_id + self._output_partition_id = output_partition_id + + def set_input_partition_id(self, partition_id: int): + self._input_partition_id = partition_id + + def set_output_partition_id(self, partition_id: int): + self._output_partition_id = partition_id + + def get_input_partition_id(self): + return self._input_partition_id + + def get_output_partition_id(self): + return self._output_partition_id + + def set_partitions(self, partition_id: int, partition: Partition): + self._partitions[partition_id] = partition + + def get_mid_partitions(self): + res = {} #{partition_id: Partition} + for partition_id, partition in self._partitions.items(): + if self._input_partition_id == partition_id or self._output_partition_id == partition_id: + continue + res[partition_id] = partition + return res + + def get_mid_partition_ids(self): + return list(self.get_mid_partitions().keys()) + + def get_input_partition(self): + if self._input_partition_id is not None: + return self._partitions[self._input_partition_id] + return None + + def get_output_partition(self): + if self._output_partition_id is not None: + return self._partitions[self._output_partition_id] + return None + + def get_partition_by_id(self, partition_id): + return self._partitions[partition_id] + + def __str__(self) -> str: + res = '' + if len(self._partitions) == 0: + return 'Empty Topo Graph.' + + input_part = self.get_input_partition() + if input_part is not None: + res += '{\n' + res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}' + res += '}\n' + + mid_parts = self.get_mid_partitions() + for i, (partition_id, part) in enumerate(mid_parts.items()): + res += '{\n' + res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}' + res += '}\n' + + output_part = self.get_output_partition() + if output_part is not None: + res += '{\n' + res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}' + res += '}\n' + + return res + + def __repr__(self) -> str: + return self.__str__() + \ No newline at end of file diff --git a/colossalai/pipeline/pipelinable.py b/colossalai/pipeline/pipelinable.py new file mode 100644 index 0000000000000000000000000000000000000000..9731530a6e15755c9d152697fec0ae0cfd102328 --- /dev/null +++ b/colossalai/pipeline/pipelinable.py @@ -0,0 +1,254 @@ +import torch +import inspect +from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses + +from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, \ + build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, \ + call_module, customized_partition +from colossalai.nn.layer.utils import CheckpointModule +from colossalai.tensor import ColoParameter +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from .layer_spec import LayerSpec + + +class PipelinableContext(InsertPostInitMethodToModuleSubClasses): + """ + A context manager to split the model into pipeline stages. + """ + + def __init__(self, policy: str = "balanced"): + super().__init__() + self._layer_spec_dict = {} + self._root_children = None + self._model = None + self._layer_spec_list = [] + self._func_dict = {} + self._policy = policy + + @property + def policy(self): + return self._policy + + @policy.setter + def policy(self, policy: str): + self._policy = policy + + @property + def layers_count(self): + return len(self._layer_spec_list) + + @property + def funcs_count(self): + return len(self._func_dict) + + def _pre_context_exec(self): + """ + The Callback function when entering the context + """ + # reserve rng states + self.cpu_rng_state = torch.get_rng_state() + self.cuda_rng_state = torch.cuda.get_rng_state() + + def _post_context_exec(self): + """ + The callback function when exiting context. + """ + + # reset rng states + torch.set_rng_state(self.cpu_rng_state) + torch.cuda.set_rng_state(self.cuda_rng_state) + + def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): + """ + The function to call at the end of the constructor of each module. + NOTE() The module may be passed to this function multiple times. + """ + # iterate over the positional arguments + # to check if an argument is a torch Module + # if found any torch Module, replace it with its layer spec + # for storage purpose + modified_args = [] + for arg in args: + if isinstance(arg, torch.nn.Module): + # if nn.Module is an argument of a non-root module, then we should convert it to layer spec, which make sure the correct init method used in the real build. + # if nn.Module is an argument of the root module, then we should just record the module instance itself, because those instance has been built outside of the context. + if id(arg) in self._layer_spec_dict: + arg = self._layer_spec_dict[id(arg)] + + modified_args.append(arg) + + # to the same for the keyword arguments + modified_kwargs = {} + for k, v in kwargs.items(): + if isinstance(v, torch.nn.Module): + v = self._layer_spec_dict[id(v)] + # (lyl)TODO: analyse ColoTensor as well + modified_kwargs[k] = v + + # keep track of the module children + # as torch.nn.Module.__init__ is called from inner module to outer module, + # the final value of self._model will be the outermost model + # e.g. if the model is torchvision.models.resnet18, then the final value of self._model + # will be the ``ResNet`` object. + self._root_children = list(module.children()) + self._model = module + + # store the children to keep the module hierarchy + layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs) + layer_spec.set_children(module.children()) + + # store the layer spec in this context + module_id = id(module) + self._layer_spec_dict[module_id] = layer_spec + + # convert all torch.nn.Parameter to colossalai.tensor.ColoParameter + name_list = [] + for name, param in module.named_parameters(): + if isinstance(param, ColoParameter): + continue + name_list.append((name, param)) + + for name, param in name_list: + if hasattr(module, name): + delattr(module, name) + setattr(module, name, ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad)) + + def to_layer_list(self, exec_seq=None): + """ + Create a layer spec list and func list with execution sequence given by user. + If exec_seq is None, we will take the module initizing order as execution order. + """ + + self._exec_seq = exec_seq + if exec_seq is None: + # if user do not provide the model executing sequence, we use the initialization order as the executing order. + children_name = [] + for child in self._root_children: + layer_spec = self._layer_spec_dict[id(child)] + if layer_spec.typename in (torch.nn.modules.container.ModuleList, + torch.nn.modules.container.Sequential): + for child_in_container in layer_spec.children: + self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)]) + for name, module in self._model.named_modules(): + if id(module) == id(child_in_container): + children_name.append(name) + break + else: + self._layer_spec_list.append(layer_spec) + for name, module in self._model.named_modules(): + if id(module) == id(child): + children_name.append(name) + break + + else: + front_funcs_list = [] + named_modules = dict(self._model.named_modules()) + for index, element in enumerate(exec_seq): + if isinstance(element, str): + if element == 'SPLIT_NODE': + continue + assert element in named_modules, f'Found invalid module name {element}, please check if you spell the module name correctly.' + + # get the layer spec based on the module ID + module = named_modules[element] + layer_spec = self._layer_spec_dict[id(module)] + + # check whether there are functions which should be executed before this module + if len(front_funcs_list) != 0: + func_key = (layer_spec, "front") + if func_key not in self._func_dict: + self._func_dict[func_key] = [] + for f in front_funcs_list: + self._func_dict[func_key].append(f) + front_funcs_list = [] + + func_key = (layer_spec, "behind") + self._layer_spec_list.append(layer_spec) + elif isinstance(element, tuple) and element[1] == "front": + front_funcs_list.append(element[0]) + else: + if func_key not in self._func_dict: + self._func_dict[func_key] = [] + if isinstance(element, tuple): + self._func_dict[func_key].append(element[0]) + else: + self._func_dict[func_key].append(element) + + def partition(self, num_chunks, pipeline_size, rank): + """ + Partitioned model will be built respect to partion policy. + The real module instance will be built in this method. + """ + if isinstance(self._policy, str): + if self._policy == "uniform": + parts = partition_uniform(len(self._layer_spec_list), pipeline_size, num_chunks)[rank] + elif self._policy == "balanced": + param_counts = [] + for layer_spec in self._layer_spec_list: + param_counts.append(layer_spec.count_params()) + parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank] + elif self._policy == "customized": + assert self._exec_seq is not None, f'An explicit exec_seq must be defined by user in customized policy mode.' + self.customized_parts = customized_partition(self._exec_seq) + assert len(self.customized_parts) == gpc.get_world_size( + ParallelMode.PIPELINE + ), f'World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partions is {len(self.customized_parts)}' + parts = self.customized_parts[rank] + else: + raise ValueError("A string partition policy should be one of ['uniform', 'balanced', 'customized'].") + elif isinstance(self._policy, dict): + parts = self._policy[rank] + else: + raise ValueError("A partition policy should be either a string or a dictionary.") + + layers_to_build = [] + for start, end in parts: + layers_to_build += self._layer_spec_list[start:end] + behind_func_dict_in_partition = {} + front_func_dict_in_partition = {} + module_list_in_partition = [] + for layer in layers_to_build: + module = layer.build() + module_list_in_partition.append(module) + if (layer, "front") in self._func_dict: + front_func_dict_in_partition[id(module)] = self._func_dict[(layer, "front")] + elif (layer, "behind") in self._func_dict: + behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")] + module_list_in_partition = torch.nn.ModuleList(module_list_in_partition) + pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition, + behind_func_dict_in_partition) + + return pipeline_model + + +class PipelinableModel(torch.nn.Module): + + def __init__(self, module_list, front_func_dict, behind_func_dict): + super().__init__() + self._module_list = module_list + self._front_func_dict = front_func_dict + self._behind_func_dict = behind_func_dict + + def forward(self, *input_tensor, **kwargs): + for module in self._module_list: + + if id(module) in self._front_func_dict: + input_tensor = exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs) + + if isinstance(module, CheckpointModule): + forward_func = module._forward + else: + forward_func = module.forward + module_kwargs = build_kwargs_for_module(forward_func, input_tensor, kwargs) + if input_tensor is None: + input_tensor = call_module(module, kwargs=module_kwargs) + elif isinstance(input_tensor, torch.Tensor): + input_tensor = call_module(module, args=(input_tensor,), kwargs=module_kwargs) + else: + input_tensor = call_module(module, args=input_tensor, kwargs=module_kwargs) + + if id(module) in self._behind_func_dict: + input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs) + + return input_tensor diff --git a/colossalai/pipeline/pipeline_process_group.py b/colossalai/pipeline/pipeline_process_group.py new file mode 100644 index 0000000000000000000000000000000000000000..c61d97ebabfa354ced39b952708188287d56cab3 --- /dev/null +++ b/colossalai/pipeline/pipeline_process_group.py @@ -0,0 +1,168 @@ +from typing import List, Dict, Tuple +import os +import threading + +from torch.distributed import rpc +import torch.distributed as dist + +from colossalai.tensor import ProcessGroup + + +class PipelineProcessGroup: + # TODO : flexible API for DP size and TP size + # In the future design mode, dp_degree and tp_degree should be removed + def __init__(self) -> None: + self.is_initialize = False + + def set_global_info(self, + rank: int, + world_size: int, + dp_degree: int = 1, + tp_degree: int = 1, + num_worker_threads: int = 1, + device: str = "cuda") -> None: + + device_mesh_size = dp_degree * tp_degree + assert world_size % device_mesh_size == 0, "world_size must be the multiple of dp_degree * tp_degree !!!" + self._num_worker_threads = num_worker_threads + + self._device_mesh_size = device_mesh_size + self._rank = rank + self._world_size = world_size + self._dp_degree = dp_degree + self._tp_degree = tp_degree + self.device = device + self._stage_num = world_size // device_mesh_size + self._pp_rank = rank // device_mesh_size + self._pp_ranks = [(rank % device_mesh_size) + i * device_mesh_size for i in range(self._stage_num)] + self._local_stage_ranks = [(rank // device_mesh_size * device_mesh_size) + i for i in range(device_mesh_size)] + + # pp_ranks + self._initialize_pp_process_group() + + # initialise tp dp process groups + self._initialize_tp_dp_process_group() + + # status + self._is_first_pp_rank = self._pp_rank == 0 + self._is_last_pp_rank = self._pp_rank == self._stage_num - 1 + + self.is_initialize = True + + # lock + self.initialise_lock = threading.Lock() + self.chimera_lock = threading.Lock() + + def _initialize_process_group(self): + stage_num = self.get_stage_num() + if stage_num == 1: + return + device = self.device + world_size = self.get_world_size() + rank = self.get_global_rank() + backend = 'nccl' if device == 'cuda' else 'gloo' + dist.init_process_group(backend, world_size=world_size, rank=rank, group_name='main_group') + + def _initialize_pp_process_group(self) -> None: + rank = self.get_global_rank() + world_size = self.get_world_size() + + # build rpc connection + options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=self._num_worker_threads) + + for pp_rank in self._pp_ranks: + options.set_device_map(f'work{pp_rank}', {rank: pp_rank}) + + rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options) + + def _initialize_tp_dp_process_group(self) -> None: + rank = self.get_global_rank() + local_stage_ranks = self.get_local_stage_global_ranks() + dp_degree = self.get_dp_degree() + tp_degree = self.get_tp_degree() + self._tp_dp_process_group = ProcessGroup(rank, local_stage_ranks, tp_degree, dp_degree) + + def get_global_rank(self): + return self._rank + + def get_world_size(self): + return self._world_size + + def get_dp_degree(self) -> int: + return self._dp_degree + + def get_tp_degree(self) -> int: + return self._tp_degree + + def get_local_device_mesh_size(self) -> int: + return self._device_mesh_size + + def get_device_mesh_num(self) -> int: + pass + + def get_stage_num(self) -> int: + return self._stage_num + + def is_first_stage(self) -> bool: + return self._is_first_pp_rank + + def is_last_stage(self) -> bool: + return self._is_last_pp_rank + + def check_pp_rank_valid(self, pp_rank: int) -> bool: + return -1 < pp_rank < self._stage_num + + def get_local_pp_rank(self) -> int: + return self._pp_rank + + def get_prev_pp_rank(self) -> int: + prev_pp_rank = self._pp_rank - 1 + if not self.check_pp_rank_valid(prev_pp_rank): + assert ValueError(f"current rank's pp_rank: {self._pp_rank} doesn't have a previous stage!") + return prev_pp_rank + + def get_next_pp_rank(self) -> int: + next_pp_rank = self._pp_rank + 1 + if not self.check_pp_rank_valid(next_pp_rank): + assert ValueError(f"current rank's pp_rank: {self._pp_rank} doesn't have a next stage!") + return next_pp_rank + + def get_local_stage_global_ranks(self) -> List[int]: + return self._local_stage_ranks + + def local_dp_rank(self) -> int: + return self._tp_dp_process_group.dp_local_rank() + + def local_tp_rank(self) -> int: + return self._tp_dp_process_group.tp_local_rank() + + def get_pp_global_ranks(self) -> int: + return self._pp_ranks + + def get_dp_global_ranks(self): + pass + + def get_tp_global_ranks(self): + pass + + def get_chimera_all_reduce_group(self, pp_rank: int): + with self.chimera_lock: + if not hasattr(self, 'chimera_groups'): + world_size = self.get_world_size() + stage_num = self.get_stage_num() + assert world_size % 2 == 0, 'world_size must be even in chimera!' + self.chimera_groups = {} + for rank in range(world_size // 2): + pair = [rank, world_size - 1 - rank] + group = dist.new_group(pair) + self.chimera_groups[pair[0]] = group + self.chimera_groups[pair[1]] = group + self.chimera_groups[pair[0] + stage_num] = group + self.chimera_groups[pair[1] + stage_num] = group + self.chimera_step_lock = threading.Lock() + self.chimera_step_lock.acquire() + + return self.chimera_groups[pp_rank] + + +ppg = PipelineProcessGroup() diff --git a/colossalai/pipeline/rpc/__init__.py b/colossalai/pipeline/rpc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d9e9d44f46c5aea6c0d211d24af21ed834a2503 --- /dev/null +++ b/colossalai/pipeline/rpc/__init__.py @@ -0,0 +1,4 @@ +from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine +from .utils import pytree_map + +__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map'] \ No newline at end of file diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1cbb0c4fb71804970b97818fee9a2179b0d31e --- /dev/null +++ b/colossalai/pipeline/rpc/_pipeline_base.py @@ -0,0 +1,1201 @@ +import inspect +import math +import threading +from abc import ABC, abstractmethod +from enum import Enum +from functools import partial +from typing import Any, Callable, Dict, List, Tuple + +import torch +import torch.distributed.rpc as rpc +from torch import autograd, nn, optim +from torch._C._distributed_rpc import PyRRef +from torch.futures import Future + +from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.pipeline.rpc.utils import ( + get_batch_lengths, + pytree_filter, + pytree_map, + split_batch, + tensor_shape_list, + type_detail, +) + + +class Phase(Enum): + FORWARD = 0 + BACKWARD = 1 + UPDATE = 2 + INPUT = 3 + + +class UniqueKey: + __slots__ = ('microbatch_id', 'phase') + microbatch_id: int + phase: Phase + + def __init__(self, microbatch_id, phase) -> None: + self.microbatch_id = microbatch_id + self.phase = phase + + def __eq__(self, __o: object) -> bool: + return (self.microbatch_id == __o.microbatch_id) and (self.phase == __o.phase) + + def __hash__(self) -> int: + return tuple.__hash__((self.microbatch_id, self.phase)) + + def __repr__(self) -> str: + return f'Key(microbatch_id={self.microbatch_id}, phase={self.phase})' + + +class WorkItem: + __slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id', + 'num_microbatches', 'forward_only') + + stage_id: int + phase: Phase + args: Tuple[Any] + kwargs: Dict[str, Any] + output: Future + microbatch_id: int + refcount: int + batch_id: int + num_microbatches: int + forward_only: bool + + def __init__(self, + stage_id, + phase, + args, + kwargs, + output, + microbatch_id, + batch_id, + num_microbatches, + forward_only, + refcount=0) -> None: + for attr_name in self.__slots__: + setattr(self, attr_name, locals()[attr_name]) + + +class BackwardCache: + __slots__ = ('checkpoint', 'stage_input_args', 'stage_input_kwargs', 'stage_outputs') + checkpoint: bool + stage_input_args: Tuple[Any] + stage_input_kwargs: Dict[Any, Any] + stage_outputs: Tuple[Any] + + def __init__(self, + stage_input_args: Tuple[Any], + stage_input_kwargs: Dict[Any, Any] = None, + stage_outputs: Tuple[Any] = None, + checkpoint: bool = False) -> None: + for arg_name in self.__slots__: + setattr(self, arg_name, locals()[arg_name]) + + +class WorkerBase(ABC): + + def __init__(self, + partition_fn: Callable, + partition_args: tuple, + pp_rank: int, + actual_stage_num: int, + num_microbatches: int, + device: str, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None) -> None: + super().__init__() + + self.pp_rank = pp_rank + self.actual_stage_num = actual_stage_num + self.num_microbatches = num_microbatches + self.checkpoint = checkpoint + + if data_process_func is not None: + self.data_process_func = partial(data_process_func, pp_rank) + + self.device = device + self._initialize_outstanding_range() + + # variable and const for context managment + self.outstanding = 0 + self.forward_times = 0 + self.backward_times = 0 + self.reset_key = UniqueKey(0, Phase.FORWARD) + + # rref of other workers + self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None + + # lock for the list + self._initialize_lock() + + # topology info + self.producer_stage_ids: List[int] = None + self.consumer_stage_ids: List[int] = None + + # module partitions + self.partition_fn = partition_fn + self.partition_args = partition_args + self.criterion = criterion + self.metric = metric + self.reset = False + + # context to maintain loop + self._initialize_context_container() + + # main loop + self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True) + self.main_loop_thread.start() + + def _get_future_by_device(self): + return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device]) + + def _initialize_outstanding_range(self): + outstanding_range = None + if self.pp_rank == self.actual_stage_num - 1: + outstanding_range = (0, 1) + else: + outstanding_range = (self.actual_stage_num, self.actual_stage_num) + self.outstanding_range = outstanding_range + + def _initialize_context_container(self): + self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict() + self.microbatch_id_to_labels: Dict[int, Any] = dict() + self.work_list: Dict[UniqueKey, WorkItem] = dict() + self.output_list: Dict[UniqueKey, WorkItem] = dict() + + def _initialize_lock(self): + self.partition_condition_lock = threading.Condition(threading.Lock()) + self.work_list_condition_lock = threading.Condition(threading.Lock()) + self.output_list_condition_lock = threading.Condition(threading.Lock()) + self.label_lock = threading.Condition(threading.Lock()) + self.reset_condition = threading.Condition(threading.Lock()) + + def _initialize_partition(self): + partition_fn = self.partition_fn + partition_args = self.partition_args + device = self.device + with self.partition_condition_lock: + self.module_partition: nn.Module = partition_fn(*partition_args).to(device) + self.partition_condition_lock.notify_all() + + def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None: + assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs" + assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None" + self.pp_rank_to_worker_rref = pp_rank_to_worker_rref + + # for some schedule need the other worker's info to initialise partition (like Chimera) + # construction of partition is executed after the registion of pp_rank_to_worker_rref + self._initialize_partition() + + # res_use works for lifecycle counter, + # if ref_use is True, lifecycle won't add. + def get_output_by_key(self, key: UniqueKey, ref_use=False) -> Any: + with self.output_list_condition_lock: + self.output_list_condition_lock.wait_for(lambda: key in self.output_list) + output_work_item = self.output_list[key] + self.output_list.pop(key) + + if not ref_use: + output_work_item.refcount += 1 + refcount = output_work_item.refcount + output = output_work_item.output + + if output_work_item.phase == Phase.FORWARD: + # lifecycle management for DAG scheduler + lifecycle = len(self.get_consumer_stage_ids()) + if self.is_model_output(): # an extra reference for scheduler collecting results + lifecycle += 1 + with self.output_list_condition_lock: + # all consumers have been satisfied, the work_item can be released + # or put it into work list again. + if refcount < lifecycle: + self.output_list[key] = output_work_item + self.output_list_condition_lock.notify_all() + elif output_work_item.phase == Phase.BACKWARD: + lifecycle = len(self.get_producer_stage_ids()) + if self._is_last_step(output_work_item): + lifecycle += 1 # an extra reference for scheduler collecting results + with self.output_list_condition_lock: + # all producers have been satisfied, the work_item can be released + # or put it into work list again. + if refcount < lifecycle: + self.output_list[key] = output_work_item + self.output_list_condition_lock.notify_all() + else: + with self.output_list_condition_lock: + self.output_list[key] = output_work_item + self.output_list_condition_lock.notify_all() + + if isinstance(output, Future): + output = output.wait() + + return output + + def get_parameters(self) -> List[torch.Tensor]: + return [p for p in self.module_partition.parameters()] + + def get_parameter_gradients(self) -> List[torch.Tensor]: + return [p.grad for p in self.module_partition.parameters()] + + def get_partition(self): + with self.partition_condition_lock: + self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) + return self.module_partition + + def get_partition_state_dict(self): + with self.partition_condition_lock: + self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) + return self.module_partition.state_dict() + + def _make_args_kwargs(self, microbatch, merge=False): + if isinstance(microbatch, dict): + if merge: + return list(microbatch.values()), {} + return [], microbatch + elif isinstance(microbatch, torch.Tensor): + return [microbatch], {} + elif isinstance(microbatch, (tuple, list)): + args = [] + kwargs = {} + for arg in microbatch: + if isinstance(arg, dict): + kwargs.update(arg) + else: + args.append(arg) + if merge: + arg_lst = args + for arg in kwargs.values(): + arg_lst.append(arg) + return arg_lst, {} + return args, kwargs + else: + raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}") + + # just for first pp_rank + def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool): + key = UniqueKey(microbatch_id, Phase.FORWARD) + output = self._get_future_by_device() + + if not self.use_middleware(): + # make args and kwargs + args, kwargs = self._make_args_kwargs(microbatch) + + work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None, + self.num_microbatches, forward_only) + with self.work_list_condition_lock: + self.work_list[key] = work_item + self.work_list_condition_lock.notify_all() + else: + # make args and kwargs + arg_lst, _ = self._make_args_kwargs(microbatch, merge=True) + + # first stage assign correct input into other stages + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + input_partition = topo.get_input_partition() + self_input_offsets = input_partition.get_output_offsets(self_partition_id) + recv_input_key = UniqueKey(microbatch_id, Phase.INPUT) + + # set input for self rank + self_arg_lst = [] + for off in self_input_offsets: + self_arg_lst.append(arg_lst[off]) + + work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None, + self.num_microbatches, forward_only) + with self.work_list_condition_lock: + self.work_list[key] = work_item + self.work_list_condition_lock.notify_all() + + # put input tensor which other nodes need into output_list as Phase.INPUT + work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None, + self.num_microbatches, forward_only) + + with self.output_list_condition_lock: + self.output_list[recv_input_key] = work_item_remote + self.output_list_condition_lock.notify_all() + + # just for last pp_rank + def set_labels(self, microbatch_id: int, microlabels: Any): + with self.label_lock: + self.microbatch_id_to_labels[microbatch_id] = microlabels + self.label_lock.notify_all() + + # just for last pp_rank + def _begin_backward(self, microbatch_id: int): + with self.work_list_condition_lock: + assert self.producer_stage_ids is not None + + key = UniqueKey(microbatch_id, Phase.BACKWARD) + output = self._get_future_by_device() + grad_wrt_loss = None + + work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None, + self.num_microbatches, False) + + self.work_list[key] = work_item + self.work_list_condition_lock.notify_all() + + def _subscribe_producer(self, microbatch_id: int, forward_only: bool): + """ + You should call this function asynchronously + """ + stage_id = self.pp_rank + output = self._get_future_by_device() + if not self.use_middleware(): + producer_num = len(self.producer_stage_ids) + subscribe_forward_futures: List[Future] = [None] * producer_num + for i in range(producer_num): + producer_stage_id = self.producer_stage_ids[i] + producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD) + producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] + subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key) + else: + producer_stage_ids = self.get_producer_stage_ids() + producer_num = len(producer_stage_ids) + if self.need_model_input(): + producer_num += 1 # for input partition + subscribe_forward_futures: List[Future] = [None] * producer_num + + # TODO(jiangziyue) get single value instead of the whole output + if self.need_model_input(): + producer_stage_id = 0 + producer_output_key = UniqueKey(microbatch_id, Phase.INPUT) + producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] + subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key) + + for i in range(0, producer_num - 1): + producer_stage_id = producer_stage_ids[i] + producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD) + producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] + subscribe_forward_futures[i + 1] = producer_worker_rref.rpc_async().get_output_by_key( + producer_output_key) + + else: + for i in range(producer_num): + producer_stage_id = producer_stage_ids[i] + producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD) + producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] + subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key( + producer_output_key) + + work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output, + microbatch_id, None, self.num_microbatches, forward_only) + + return work_item_from_producer + + # TODO(jiangziyue) Profile the side effect of the lock for lifecycle protection and consider a better one. + def subscribe_producer(self, microbatch_id: int, forward_only: bool): + key = UniqueKey(microbatch_id, Phase.FORWARD) + with self.work_list_condition_lock: + if key not in self.work_list: + # On current PP middleware design for DAG, get_output_by_key used by _subscribe_producer + # can only be executed once for every producer-consumer stage pair, which is necessary + # to count the lifecycle of work_item. So, keeping the _subscribe_producer in the same + # lock of work_item queue operation gurantees the consistency of lifecycle counter. + work_item_from_producer = self._subscribe_producer(microbatch_id, forward_only) + self.work_list[key] = work_item_from_producer + self.work_list_condition_lock.notify_all() + + def _subscribe_consumer(self, microbatch_id: int): + """ + You should call this function asynchronously + """ + stage_id = self.pp_rank + output = self._get_future_by_device() + if not self.use_middleware(): + consumer_stage_ids = self.consumer_stage_ids + else: + consumer_stage_ids = self.get_consumer_stage_ids() + consumer_num = len(consumer_stage_ids) + subscribe_backward_futures: List[Future] = [None] * consumer_num + for i in range(consumer_num): + consumer_stage_id = consumer_stage_ids[i] + consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD) + consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id] + subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key) + + # flatten args + work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output, + microbatch_id, None, self.num_microbatches, False) + + return work_item_from_consumer + + def subscribe_consumer(self, microbatch_id: int): + key = UniqueKey(microbatch_id, Phase.BACKWARD) + with self.work_list_condition_lock: + if key not in self.work_list: + # On current PP middleware design for DAG, get_output_by_key used by subscribe_consumer + # can only be executed once for every producer-consumer stage pair, which is necessary + # to count the lifecycle of work_item. So, keeping the subscribe_consumer in the same + # lock of work_item queue operation gurantees the consistency of lifecycle counter. + work_item_from_consumer = self._subscribe_consumer(microbatch_id) + self.work_list[key] = work_item_from_consumer + self.work_list_condition_lock.notify_all() + + def get_producer_stage_ids(self): + producer_stage_ids = [] + rank = self.pp_rank + if not self.use_middleware(): + prev_rank = rank - 1 + if prev_rank >= 0: + producer_stage_ids.append(prev_rank) + else: + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + input_partition_ids = self_partition.get_input_partition_ids() + model_input_partition_id = topo.get_input_partition_id() + for partition_id in input_partition_ids: + # ignore input partition in current implementation. + # it will be specially tackled. + if partition_id != model_input_partition_id: + producer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo)) + return producer_stage_ids + + def get_consumer_stage_ids(self): + consumer_stage_ids = [] + rank = self.pp_rank + if not self.use_middleware(): + next_rank = rank + 1 + if next_rank <= self.actual_stage_num - 1: + consumer_stage_ids.append(next_rank) + else: + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + output_partition_ids = self_partition.get_output_partition_ids() + model_output_partition_id = topo.get_output_partition_id() + for partition_id in output_partition_ids: + if model_output_partition_id != partition_id: + consumer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo)) + return consumer_stage_ids + + def _get_producer_consumer(self) -> None: + rank = self.pp_rank + assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed" + assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed" + + # should be aranged in order, the order of the input of current forward + self.producer_stage_ids = self.get_producer_stage_ids() + self.consumer_stage_ids = self.get_consumer_stage_ids() + + def pp_rank_to_partition_id(self, pp_rank: int, topo: Topo): + partition_ids = topo.get_mid_partition_ids() + return partition_ids[pp_rank] + + def partition_id_to_pp_rank(self, partition_id: int, topo: Topo): + partition_ids = topo.get_mid_partition_ids() + for i, id in enumerate(partition_ids): + if id == partition_id: + return i + + def get_topo(self): + with self.partition_condition_lock: + self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) + if hasattr(self.module_partition, '_topo'): + return self.module_partition._topo + else: + return None + + def use_middleware(self): + topo = self.get_topo() + return topo is not None + + # TODO(jiangziyue) get single value instead of the whole output + def _get_real_args_kwargs_fwd(self, args_or_kwargs): + if not self.use_middleware(): + args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future) + if args_or_kwargs is not None: + if isinstance(args_or_kwargs, dict): + pass + else: + flatten_args = [] + pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) + args_or_kwargs = flatten_args + else: + args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future) + if args_or_kwargs is not None: + if isinstance(args_or_kwargs, dict): + pass + else: + flatten_args = [] + if self.is_first_stage(): + pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) + # TODO get by offset + else: + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + model_input_partition_id = topo.get_input_partition_id() + input_vals = self_partition.get_input_vals() + producer_stage_ids = self.get_producer_stage_ids() + if self.need_model_input(): + # 0 for data from input batch + # >= 1 for data from prev stages + base = 1 + else: + # data from prev stages + base = 0 + for val in input_vals: + val_pos = val.get() + src_partition_id = val_pos.partition_id + src_offset = val_pos.offset + src_index = base + src_partition = topo.get_partition_by_id(src_partition_id) + output_len = len(src_partition.get_output_vals()) + # data from not-input partition + if src_partition_id != model_input_partition_id: + src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo) + src_index = base + for i, stage_id in enumerate(producer_stage_ids): + if stage_id == src_stage_id: + src_index += i + break + else: # data from input partition + src_index = 0 + # when output_len = 1, not iterable + if output_len == 1: + target = args_or_kwargs[src_index] + else: + target = args_or_kwargs[src_index][src_offset] + flatten_args.append(target) + args_or_kwargs = flatten_args + return args_or_kwargs + + # TODO(jiangziyue) get single value instead of the whole output + def _get_real_args_kwargs_bwd(self, args_or_kwargs): + if not self.use_middleware(): + args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future) + if args_or_kwargs is not None: + if isinstance(args_or_kwargs, dict): + pass + else: + flatten_args = [] + pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) + args_or_kwargs = flatten_args + else: + args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future) + if args_or_kwargs is not None: + flatten_args = [] + # TODO get by offset + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + output_vals = self_partition.get_output_vals() + consumer_stage_ids = self.get_consumer_stage_ids() + for val_list in output_vals: + # An output may be passed to many down stages. + target = None + for val_pos in val_list.get(): + dst_partition_id = val_pos.partition_id + dst_offset = val_pos.offset + dst_partition = topo.get_partition_by_id(dst_partition_id) + input_len = len(dst_partition.get_input_vals()) + dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo) + for i, stage_id in enumerate(consumer_stage_ids): + if stage_id == dst_stage_id: + dst_index = i + break + if input_len == 1: + part_grad = args_or_kwargs[dst_index] + else: + part_grad = args_or_kwargs[dst_index][dst_offset] + + if target is None: + target = part_grad + elif part_grad is not None: + target += part_grad + else: + continue + flatten_args.append(target) + args_or_kwargs = flatten_args + return args_or_kwargs + + @abstractmethod + def _get_work_item_key(self) -> UniqueKey: + """ + this method control the order of the microbatch to consume + """ + + def is_first_stage(self): + return self.pp_rank == 0 + + def is_last_stage(self): + return self.pp_rank == self.actual_stage_num - 1 + + def need_model_input(self): + need_input = False + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + self_partition = topo.get_partition_by_id(self_partition_id) + partition_inputs = self_partition.get_input_partition_ids() + model_input_partition_id = topo.get_input_partition_id() + if model_input_partition_id in partition_inputs: + need_input = True + return not self.is_first_stage() and need_input + + def is_model_output(self): + return self.is_last_stage() + + def is_model_input(self): + return self.is_first_stage() + + def _default_data_process_func(self, args_kwargs): + if self.is_first_stage(): + args = args_kwargs[0] + kwargs = args_kwargs[1] + else: + args = args_kwargs + kwargs = {} + + return args, kwargs + + def _consume_work_item_by_phase(self, work_item: WorkItem): + phase = work_item.phase + args = work_item.args + kwargs = work_item.kwargs + microbatch_id = work_item.microbatch_id + forward_only = work_item.forward_only + data_process_func = getattr(self, 'data_process_func', self._default_data_process_func) + consume_result = None + + is_first_stage = self.is_first_stage() + is_last_stage = self.is_last_stage() + + if phase == Phase.FORWARD: + # remind its consumer to get data before forward + if not is_last_stage: + for stage_id in self.consumer_stage_ids: + consumer_worker_rref = self.pp_rank_to_worker_rref[stage_id] + consumer_worker_rref.remote().subscribe_producer(microbatch_id, forward_only) + + # sustain pipeline context + self.forward_times += 1 + if not forward_only: + self.outstanding += 1 + + # parse and integrate args and kwargs + if is_first_stage: + args = self._get_real_args_kwargs_fwd(args) + kwargs = self._get_real_args_kwargs_fwd(kwargs) + args_kwargs = (args, kwargs) + else: + args_kwargs = self._get_real_args_kwargs_fwd(args) + + if not forward_only: + pytree_map(args_kwargs, + lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False), + process_types=torch.Tensor) + + args, kwargs = data_process_func(args_kwargs) + + stage_outputs = None + stage_input_args = args + stage_input_kwargs = kwargs + use_checkpoint = None + + if forward_only: + with torch.no_grad(): + consume_result = self.module_partition(*args, **kwargs) + + if is_last_stage and self.criterion: + with self.label_lock: + self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels) + labels = self.microbatch_id_to_labels.pop(microbatch_id) + loss: torch.Tensor = self.criterion(consume_result, labels) + if self.metric is not None: + metric_result = self.metric(consume_result, labels) + if isinstance(metric_result, torch.Tensor): + metric_result = metric_result.item() + else: + metric_result = None + consume_result = [loss.item(), metric_result] + + # last stage doesn't need to do checkpoint, for it will do backward instantly + stage_input_args = None + stage_input_kwargs = None + stage_outputs = consume_result + + elif self.checkpoint and not is_last_stage: + with torch.no_grad(): + consume_result = self.module_partition(*args, **kwargs) + + stage_outputs = consume_result + use_checkpoint = True + + else: + consume_result = self.module_partition(*args, **kwargs) + + if is_last_stage and self.criterion: + with self.label_lock: + self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels) + labels = self.microbatch_id_to_labels.pop(microbatch_id) + loss: torch.Tensor = self.criterion(consume_result, labels) + if self.metric is not None: + metric_result = self.metric(consume_result, labels) + if isinstance(metric_result, torch.Tensor): + metric_result = metric_result.item() + else: + metric_result = None + + consume_result = [loss.item(), metric_result] + else: + loss = consume_result + + stage_outputs = loss + use_checkpoint = False + + if not forward_only: + self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_input_args, + stage_input_kwargs, + stage_outputs, + checkpoint=use_checkpoint) + # if not forward_only, do the backward + if not forward_only: + if is_last_stage: # if it is the last stage, trigger backward automatic + self._begin_backward(microbatch_id) + + elif phase == Phase.BACKWARD: + # remind its producer to get data before backward + if not is_first_stage: + for stage_id in self.producer_stage_ids: + producer_worker_rref = self.pp_rank_to_worker_rref[stage_id] + producer_worker_rref.remote().subscribe_consumer(microbatch_id) + self.backward_times += 1 + self.outstanding -= 1 + + assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache" + backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id) + + stage_outputs = backward_cache.stage_outputs + stage_input_args = backward_cache.stage_input_args + stage_input_kwargs = backward_cache.stage_input_kwargs + use_checkpoint = backward_cache.checkpoint + + if use_checkpoint: + stage_outputs = [self.module_partition(*stage_input_args, **stage_input_kwargs)] + + # overlap recompute and future.wait + if not is_last_stage: + grad_tensors = self._get_real_args_kwargs_bwd(args) + else: + grad_tensors = None + + # take tensor only (for only tensor can do backward) + # TODO(jiangziyue) : All values which should do bp are torch.Tensor? + stage_outputs = pytree_filter(lambda x: True, stage_outputs, process_types=torch.Tensor) + grad_tensors = pytree_filter(lambda x: True, grad_tensors, process_types=torch.Tensor) + + # output all input's grad to producer, even it has no grad(output None) + # to make the offset aligned to the topo's record. + if grad_tensors is not None: + filtered_outputs = [] + filtered_grads = [] + for i, grad in enumerate(grad_tensors): + stage_output = stage_outputs[i] + if stage_output.requires_grad and grad is not None: + filtered_outputs.append(stage_output) + filtered_grads.append(grad) + + stage_outputs = filtered_outputs + grad_tensors = filtered_grads + + autograd.backward(stage_outputs, grad_tensors=grad_tensors) + + # collect grad of input tensor + consume_result = [] + if not is_first_stage: + # In current design, input mush be a flatten args. + for arg in stage_input_args: + if isinstance(arg, torch.Tensor): + consume_result.append(arg.grad) + else: + consume_result.append(None) + + else: + raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}") + + return consume_result + + def _get_store_len(self): + return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}' + + def _get_parameter_grad_sum(self): + grad_sum = 0 + for p in self.module_partition.parameters(): + if p.grad is not None: + grad_sum += p.grad.sum() + return grad_sum + + def _is_first_step(self, work_item: WorkItem) -> bool: + return work_item.phase == Phase.FORWARD and work_item.microbatch_id == 0 + + def _is_last_step(self, work_item: WorkItem) -> bool: + if work_item.forward_only: + last_phase = Phase.FORWARD + else: + last_phase = Phase.BACKWARD + is_last_phase = work_item.phase == last_phase + is_last_microbatch = work_item.microbatch_id == self.num_microbatches - 1 + return is_last_phase and is_last_microbatch + + def _hook_before_step(self): + pass + + # install the main loop to wait for next batch input + def _wait_for_reset(self): + with self.reset_condition: + self.reset_condition.wait_for(lambda: self.reset) + self.reset = False + + # do the main loop to consume ready_list + def _work_loop(self): + # for init + self._get_producer_consumer() + torch.cuda.set_device(ppg.get_local_pp_rank()) + + # main loop + while True: + work_item_key = self._get_work_item_key() + # move current work item to output_list to activate subscribe in advance + with self.work_list_condition_lock: + self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list) + work_item = self.work_list[work_item_key] + + with self.output_list_condition_lock: + # assert work_item_key not in self.output_list + self.output_list[work_item_key] = work_item + self.output_list_condition_lock.notify_all() + + consume_result = self._consume_work_item_by_phase(work_item) + + with self.work_list_condition_lock: + self.work_list.pop(work_item_key) + work_item.output.set_result(consume_result) + + # if is last step in one batch reset context and do step + if self._is_last_step(work_item): + self._hook_before_step() + if hasattr(self, 'optimizer') and not work_item.forward_only: + self.step() + self._wait_for_reset() + + # reset context and resume loop + def reset_context(self): + self.forward_times = 0 + self.backward_times = 0 + self.outstanding = 0 + self._initialize_outstanding_range() + with self.work_list_condition_lock: + self.work_list.clear() + + with self.output_list_condition_lock: + self.output_list.clear() + + with self.reset_condition: + self.reset = True + self.reset_condition.notify_all() + + def initialize_optimizer(self, optimizer_class: type, **kwargs): + # TODO(jiangziyue) it's temporary code to deal with empty module partition. + # After tracer fixed, remove this part. + if len(list(self.module_partition.parameters())) > 0: + self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs) + self.step_lock = threading.Lock() + self.step_lock.acquire() + + def wait_for_step(self): + self.step_lock.acquire() + + def step(self): + # TODO(jiangziyue) it's temporary code to deal with empty module partition. + # After tracer fixed, remove this part. + if len(list(self.module_partition.parameters())) > 0: + self.optimizer.step() + self.optimizer.zero_grad() + self.step_lock.release() + + +class PipelineEngineBase(ABC, nn.Module): + + def __init__(self, + worker_type, + partition_fn: Callable, + stage_num, + num_microbatches, + device: str, + use_1F1B=False, + chunk: int = 1, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None) -> None: + super().__init__() + self.worker_type = worker_type + self.partition_fn: Callable = partition_fn + self.chunk = chunk + self.criterion = criterion + self.metric = metric + self.num_microbatches = num_microbatches + self.device = device + self.use_1F1B = use_1F1B + self.stage_num = stage_num + self.checkpoint = checkpoint + self.data_process_func = data_process_func + + self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict() + + self.step_futs: List[Future] = [] + + self._check_argument() + self._create_pp_rank_to_rpc_worker_id() + self._create_pp_rank_to_module_partition_id() + self._init_worker() + + def _check_argument(self) -> None: + # make virtual stage num + self.virtual_stage_num = self.stage_num * self.chunk + assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!" + + # check data_process_func + data_process_func = self.data_process_func + if data_process_func is not None: + assert callable(data_process_func), "data_process_func must be a function" + assert '' not in data_process_func.__repr__(), "data_process_func must be a global function" + assert '' not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression" + sig = inspect.signature(data_process_func) + assert len( + sig.parameters + ) == 2, f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead" + + def _get_actual_stage_num(self) -> int: + return self.stage_num if self.chunk == 1 else self.virtual_stage_num + + def _create_pp_rank_to_rpc_worker_id(self) -> None: + """create a map from model partition to stage_id, which is useful when use_interleave is True. + e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then + pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part + of partitions will be moved to device 0 and the others to device 1 + """ + stage_num = self.stage_num + actual_stage_num = self._get_actual_stage_num() + self.pp_rank_to_rpc_worker_id = [0] * actual_stage_num + for pp_rank in range(actual_stage_num): + self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank % stage_num + + def _create_pp_rank_to_module_partition_id(self) -> None: + """By default(both fill drain and 1F1B), length of model partitions equal to + actual_stage_num, so allocate model partition to corresponding stage + """ + actual_stage_num = self._get_actual_stage_num() + self.pp_rank_to_module_partition_id = [0] * actual_stage_num + for pp_rank in range(actual_stage_num): + self.pp_rank_to_module_partition_id[pp_rank] = pp_rank + + def _init_worker(self) -> None: + actual_stage_num = self._get_actual_stage_num() + + worker_type = self.worker_type + checkpoint = self.checkpoint + num_microbatches = self.num_microbatches + device = self.device + criterion = self.criterion + metric = self.metric + partition_fn = self.partition_fn + chunk = self.chunk + data_process_func = self.data_process_func + + for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)): + partition_id = self.pp_rank_to_module_partition_id[pp_rank] + partition_args = (partition_id, chunk, actual_stage_num) + rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank] + if device[:4] == 'cuda': + device = f'cuda:{rpc_worker_id}' + self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id, + worker_type, + args=(partition_fn, partition_args, pp_rank, + actual_stage_num, num_microbatches, device, + criterion, metric, checkpoint, data_process_func)) + + # let each worker know global worker rref (include itself) + sync_futs = [] + for pp_rank in self.pp_rank_to_worker_rref: + fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async().sync_global_worker_rrefs(self.pp_rank_to_worker_rref) + sync_futs.append(fut) + + for fut in sync_futs: + fut.wait() + + def remote_parameters(self) -> Dict[int, List[torch.Tensor]]: + parameters = {} + actual_stage_num = self._get_actual_stage_num() + for stage_id in range(actual_stage_num): + parameters[stage_id] = [] + worker_rref = self.pp_rank_to_worker_rref[stage_id] + for p in worker_rref.rpc_sync().get_parameters(): + parameters[stage_id].append(p) + return parameters + + def remote_grad(self) -> Dict[int, List[torch.Tensor]]: + grads = {} + actual_stage_num = self._get_actual_stage_num() + for stage_id in range(actual_stage_num): + grads[stage_id] = [] + worker_rref = self.pp_rank_to_worker_rref[stage_id] + for grad in worker_rref.rpc_sync().get_parameter_gradients(): + grads[stage_id].append(grad) + return grads + + def get_input_pp_ranks(self) -> List[int]: + return [0] + + def get_output_pp_ranks(self) -> List[int]: + return [self._get_actual_stage_num() - 1] + + def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], + output_pp_ranks: List[int], ret_future): + actual_stage_num = self._get_actual_stage_num() + use_1F1B = self.use_1F1B + if microbatch_id >= actual_stage_num: + if forward_only or not use_1F1B: + for pp_rank in output_pp_ranks: + ret_future[pp_rank][microbatch_id - actual_stage_num].wait() + else: + key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD) + for pp_rank in input_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + worker_rref.rpc_sync().get_output_by_key(key, ref_use=True) + + def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]: + num_microbatches = self.num_microbatches + return {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks} + + def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool): + for pp_rank in input_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + # TODO : add relationship between input_pp_ranks and parts of microbatch + worker_rref.remote().set_input(microbatch_id, microbatch, forward_only) + + def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels): + for pp_rank in output_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + # TODO : add relationship between output_pp_ranks and parts of microlabels + worker_rref.remote().set_labels(microbatch_id, microlabels) + + # TODO(jiangziyue) : get model output with single value, instead of merging into last stage. + def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]): + key = UniqueKey(microbatch_id, Phase.FORWARD) + for pp_rank in output_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + ret_future[pp_rank][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key) + + def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]): + if not forward_only: + for pp_rank in input_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD) + worker_rref.rpc_sync().get_output_by_key(key) + + def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]): + forward_result = [] + for pp_rank in output_pp_ranks: + worker_forward_result = [None] * self.num_microbatches + for microbatch_id in range(self.num_microbatches): + ret = ret_future[pp_rank][microbatch_id].wait() + # TODO : more stable format + ret = [ret] if isinstance(ret, torch.Tensor) else ret + worker_forward_result[microbatch_id] = ret + + worker_forward_result = list(zip(*worker_forward_result)) + forward_result.extend(worker_forward_result) + + return forward_result + + def _reset_worker(self): + actual_stage_num = self._get_actual_stage_num() + for pp_rank in range(actual_stage_num): + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + fut = worker_rref.rpc_async().reset_context() + self.step_futs.append(fut) + + for fut in self.step_futs: + fut.wait() + + def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False): + batch_lengths = get_batch_lengths(batch) + batch_length = batch_lengths[0] + + if labels is not None and not forward_only: + assert hasattr( + self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward" + + num_microbatches = self.num_microbatches + + assert batch_length >= num_microbatches, "num_microbatches is greater than the size of a batch, which is illegal" + microbatch_size = math.ceil(batch_length / num_microbatches) + device = self.device + + # If Chimera mode is used, then rank of down pipeline is excluded from 'input_pp_ranks' or 'output_pp_ranks' + input_pp_ranks = self.get_input_pp_ranks() + output_pp_ranks = self.get_output_pp_ranks() + + # a cache to collect data and control flow + ret_future = self._create_ret_future(output_pp_ranks) + + for microbatch_id in range(num_microbatches): + # control data input speed + # to prevent exceed of wait limitations + self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future) + batch_start = microbatch_size * microbatch_id + batch_end = min(batch_start + microbatch_size, batch_length) + + # set input + microbatch = split_batch(batch, batch_start, batch_end, device) + self._set_input(input_pp_ranks, microbatch_id, microbatch, forward_only) + + # set labels + if labels is not None: + # microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)] + microlabels = split_batch(labels, batch_start, batch_end, device) + self._set_labels(output_pp_ranks, microbatch_id, microlabels) + + # get data asynchronously + self._subscribe_forward(microbatch_id, output_pp_ranks, ret_future) + + # wait for first rank to ensure all backwards are done + self._ensure_backward(forward_only, input_pp_ranks) + + # collect forward result + forward_result = self._collect_forward_result(output_pp_ranks, ret_future) + + if not forward_only and hasattr(self, 'optimizer_class'): + # wait for all step + for pp_rank in self.pp_rank_to_worker_rref: + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + worker_rref.rpc_sync().wait_for_step() + + self._reset_worker() # reset worker attributes for next batch + return forward_result + + def initialize_optimizer(self, optimizer_class: type, **kwargs): + self.optimizer_class = optimizer_class + for pp_rank in self.pp_rank_to_worker_rref: + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + worker_rref.remote().initialize_optimizer(optimizer_class, **kwargs) + + def step(self): + actual_stage_num = self._get_actual_stage_num() + for pp_rank in range(actual_stage_num): + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + fut = worker_rref.rpc_async().step() + self.step_futs.append(fut) + + for fut in self.step_futs: + fut.wait() diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/pipeline/rpc/_pipeline_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..555955583def62d937c4bddf327434351f0919f7 --- /dev/null +++ b/colossalai/pipeline/rpc/_pipeline_schedule.py @@ -0,0 +1,348 @@ +import threading +from typing import Callable, Dict, List + +import torch +import torch.distributed as dist +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem) +from torch._C._distributed_rpc import PyRRef +from torch.futures import Future + +# Implementation of different Pipeline schedule +# Worker defines the worker for each stage +# PipelineEngine is the class for use + + +class FillDrainWorker(WorkerBase): + + def _get_work_item_key(self) -> UniqueKey: + # execute backward first (if backward phase in work_list) + num_microbatches = self.num_microbatches + + if self.forward_times < num_microbatches: + target_phase = Phase.FORWARD + target_microbatch_id = self.forward_times + else: + target_phase = Phase.BACKWARD + target_microbatch_id = self.backward_times + + target_key = UniqueKey(target_microbatch_id, target_phase) + + with self.work_list_condition_lock: + self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list) + + return target_key + + +class FillDrainPipelineEngine(PipelineEngineBase): + + def __init__(self, + partition_fn: Callable, + stage_num: int, + num_microbatches: int, + device: str, + chunk: int = 1, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None) -> None: + + if chunk > 1: + assert num_microbatches % stage_num == 0, \ + "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" + use_1F1B = False + + super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, + metric, checkpoint, data_process_func) + + +class OneFOneBWorker(WorkerBase): + + def _get_work_item_key(self) -> UniqueKey: + # execute backward first (if backward phase in work_list) + pp_rank = self.pp_rank + actual_stage_num = self.actual_stage_num + num_microbatches = self.num_microbatches + is_last_stage = pp_rank == actual_stage_num - 1 + + if self.outstanding <= self.outstanding_range[0]: + target_phase = Phase.FORWARD + target_microbatch_id = self.forward_times + elif self.outstanding >= self.outstanding_range[1]: + target_phase = Phase.BACKWARD + target_microbatch_id = self.backward_times + else: + raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]") + + target_key = UniqueKey(target_microbatch_id, target_phase) + + # change outstanding_range at: + # 1. forward times reach actual_stage_num, this is the end of continuous forward + # 2. forward times reach num_microbatches, this is the end of 1F1B mode + if not is_last_stage and \ + target_key.phase == Phase.FORWARD: + if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2: + # Why need num_microbatches > 2 ? Because there is no steady stage when num_microbatches <= 2 + outstanding_min = actual_stage_num - pp_rank - 1 + outstanding_max = actual_stage_num - pp_rank + self.outstanding_range = (outstanding_min, outstanding_max) + elif target_key.microbatch_id == num_microbatches - 1: + self.outstanding_range = (0, 0) + + return target_key + + +class OneFOneBPipelineEngine(PipelineEngineBase): + + def __init__(self, + partition_fn: Callable, + stage_num: int, + num_microbatches: int, + device: str, + chunk: int = 1, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None) -> None: + + if chunk > 1: + assert num_microbatches % stage_num == 0, \ + "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" + # assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk" + use_1F1B = True + + super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, + metric, checkpoint, data_process_func) + + +class ChimeraWorker(WorkerBase): + + def _get_producer_consumer(self) -> None: + rank = self.pp_rank + min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num + max_pp_rank = min_pp_rank + self.actual_stage_num - 1 + + assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed" + assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed" + + # should be aranged in order, the order of the input of current forward + self.producer_stage_ids = [] + self.consumer_stage_ids = [] + + # Just for demo + prev_rank = rank - 1 + next_rank = rank + 1 + if prev_rank >= min_pp_rank: + self.producer_stage_ids.append(prev_rank) + if next_rank <= max_pp_rank: + self.consumer_stage_ids.append(next_rank) + + def _get_work_item_key(self) -> UniqueKey: + pp_rank = self.pp_rank + stage_num = self.actual_stage_num + real_microbatch_num = self.num_microbatches // 2 + + forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num + forward_block_num = self.forward_times // forward_block_size + + if self.forward_times >= real_microbatch_num or \ + ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times): + target_phase = Phase.BACKWARD + target_microbatch_id = self.backward_times + else: # others + target_phase = Phase.FORWARD + target_microbatch_id = self.forward_times + + # In up pipeline, microbatch_id to consume is 0, 2, 4 (2n) + # In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1) + real_target_microbatch_id = target_microbatch_id * 2 + if pp_rank >= stage_num: + real_target_microbatch_id += 1 + target_key = UniqueKey(real_target_microbatch_id, target_phase) + + with self.work_list_condition_lock: + self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list) + return target_key + + def _initialize_partition(self): + # In order to ensure the down pipeline share the same parameter + # with the up pipeline, partition of down partition will be copied + # from corresponding up stage + pp_rank = self.pp_rank + stage_num = self.actual_stage_num + device = self.device + if pp_rank < stage_num: + super()._initialize_partition() + else: + # if it is down pipeline, create partition by origin method + co_up_pp_worker_rref = self.pp_rank_to_worker_rref[pp_rank - stage_num] + # get the coresponding model state dict and wait for its init + state_dict = co_up_pp_worker_rref.rpc_sync().get_partition_state_dict() + super()._initialize_partition() + self.module_partition.load_state_dict(state_dict) + + # init group for chimera in ppg + ppg.get_chimera_all_reduce_group(pp_rank) + + # lock for step sync + self.step_sync_lock = threading.Lock() + self.step_sync_lock.acquire() + + self.have_grad_lock = threading.Lock() + self.have_grad_lock.acquire() + + def _get_lock_gradient(self): + self.have_grad_lock.acquire() + grads = self.get_parameter_gradients() + self.step_sync_lock.release() + return grads + + def is_first_stage(self): + return (self.pp_rank % self.actual_stage_num) == 0 + + def is_last_stage(self): + return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1 + + def _is_last_step(self, work_item: WorkItem) -> bool: + if work_item.forward_only: + last_phase = Phase.FORWARD + else: + last_phase = Phase.BACKWARD + is_last_phase = work_item.phase == last_phase + last_microbatch_id = self.num_microbatches - 1 + if self.pp_rank < self.actual_stage_num: + last_microbatch_id -= 1 + is_last_microbatch = work_item.microbatch_id == last_microbatch_id + return is_last_phase and is_last_microbatch + + def _get_step_order(self) -> List[int]: + # TODO : If you want to extend it to multi head chimera, overwrite here + stage_num = self.actual_stage_num + pp_rank = self.pp_rank + # pp_rank in the same device + local_device_pp_ranks = [pp_rank, stage_num * 2 - pp_rank - 1] + local_device_pp_ranks.sort(reverse=min(local_device_pp_ranks) < stage_num // 2) + return local_device_pp_ranks + + def _hook_before_step(self): + self.have_grad_lock.release() + pp_rank = self.pp_rank + stage_num = self.actual_stage_num + co_pp_rank = (pp_rank + stage_num) % (2 * stage_num) + + # if currrent pp_rank is not the first to do step + # wait its previous pp_rank finish step + grads = self.get_parameter_gradients() + + # send + co_worker = self.pp_rank_to_worker_rref[co_pp_rank] + co_grads = co_worker.rpc_sync()._get_lock_gradient() + # sync + self.step_sync_lock.acquire() + for i in range(len(grads)): + grads[i] += co_grads[i] + + +class ChimeraPipelineEngine(PipelineEngineBase): + + def __init__(self, + partition_fn: Callable, + stage_num: int, + num_microbatches: int, + device: str, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None) -> None: + + assert num_microbatches % stage_num == 0, \ + "In Chimera, num_microbatches must be the multiply of stage_num!" + use_1F1B = False + chunk = 1 + + super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, + metric, checkpoint, data_process_func) + + def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], + output_pp_ranks: List[int], ret_future): + pass + + def _create_pp_rank_to_rpc_worker_id(self) -> None: + stage_num = self.stage_num + self.pp_rank_to_rpc_worker_id = [0] * (stage_num * 2) + for pp_rank in range(stage_num): + self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank + self.pp_rank_to_rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1 + + def _create_pp_rank_to_module_partition_id(self) -> None: + stage_num = self.stage_num + self.pp_rank_to_module_partition_id = [0] * (stage_num * 2) + for pp_rank in range(stage_num): + self.pp_rank_to_module_partition_id[pp_rank] = pp_rank + self.pp_rank_to_module_partition_id[pp_rank + stage_num] = pp_rank + + def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]: + num_microbatches = self.num_microbatches + stage_num = self.stage_num + up_ret_future = {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks} + down_ret_future = {pp_rank + stage_num: [None] * num_microbatches for pp_rank in output_pp_ranks} + # merge up and down + return {**up_ret_future, **down_ret_future} + + def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool): + # offset is 0 for all the ranks in up pipeline + # offset is stage_num for all the ranks in down pipeline + offset = (microbatch_id % 2) * self.stage_num + for pp_rank in input_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset] + worker_rref.remote().set_input(microbatch_id, microbatch, forward_only) + + def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels): + # offset is 0 for all the ranks in up pipeline + # offset is stage_num for all the ranks in down pipeline + offset = (microbatch_id % 2) * self.stage_num + for pp_rank in output_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset] + worker_rref.remote().set_labels(microbatch_id, microlabels) + + def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]): + key = UniqueKey(microbatch_id, Phase.FORWARD) + offset = (microbatch_id % 2) * self.stage_num + for pp_rank in output_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset] + ret_future[pp_rank + offset][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key) + + def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]): + stage_num = self.stage_num + num_microbatches = self.num_microbatches + if not forward_only: + for pp_rank in input_pp_ranks: + up_last_microbatch_id = num_microbatches - 2 + down_last_microbatch_id = num_microbatches - 1 + + up_worker_rref = self.pp_rank_to_worker_rref[pp_rank] + down_worker_rref = self.pp_rank_to_worker_rref[pp_rank + stage_num] + + up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD) + down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD) + up_worker_rref.rpc_sync().get_output_by_key(up_key) + down_worker_rref.rpc_sync().get_output_by_key(down_key) + + def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[PyRRef, List[Future]]): + """Logic of collection of forward in Chimera. + Currently, only one input one output model is supported + """ + stage_num = self.stage_num + forward_result = [] + for pp_rank in output_pp_ranks: + worker_forward_result = [None] * self.num_microbatches + for microbatch_id in range(self.num_microbatches): + offset = (microbatch_id % 2) * stage_num + ret = ret_future[pp_rank + offset][microbatch_id].wait() + ret = [ret] if isinstance(ret, torch.Tensor) else ret + worker_forward_result[microbatch_id] = ret + + worker_forward_result = list(zip(*worker_forward_result)) + forward_result.extend(worker_forward_result) + + return forward_result diff --git a/colossalai/pipeline/rpc/utils.py b/colossalai/pipeline/rpc/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..77d601173b133c6418885f7f19beb3d19a14b14d --- /dev/null +++ b/colossalai/pipeline/rpc/utils.py @@ -0,0 +1,140 @@ +import argparse +import os +import warnings +from typing import Any, Callable, Dict, List, Tuple, Type, Union + +import torch +import torch.distributed.rpc as rpc +import torch.multiprocessing as mp +from colossalai.initialize import launch +from colossalai.pipeline.pipeline_process_group import ppg +from torch._C._distributed_rpc import _is_current_rpc_agent_set +from torch.futures import Future + + +def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any: + """process object recursively, like pytree + + Args: + obj (:class:`Any`): object to process + fn (:class:`Callable`): a function to process subobject in obj + process_types (:class: `type | tuple[type]`): types to determine the type to process + map_all (:class: `bool`): if map_all is True, then any type of element will use fn + + Returns: + :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn` + """ + if isinstance(obj, dict): + return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj} + elif isinstance(obj, tuple): + return tuple(pytree_map(o, fn, process_types, map_all) for o in obj) + elif isinstance(obj, list): + return list(pytree_map(o, fn, process_types, map_all) for o in obj) + elif isinstance(obj, process_types): + return fn(obj) + else: + return fn(obj) if map_all else obj + + +def tensor_shape_list(obj): + return pytree_map(obj, fn=lambda x: x.shape, process_types=torch.Tensor) + + +def get_batch_lengths(batch): + lengths = [] + pytree_map(batch, fn=lambda x: lengths.append(len(x)), process_types=torch.Tensor) + return lengths + + +def split_batch(batch: Any, start, stop, device: str): + if device == 'cuda': + fn = lambda x: x[start:stop].cuda() + else: + fn = lambda x: x[start:stop] + return pytree_map(batch, fn=fn, process_types=torch.Tensor) + + +def type_detail(obj): + return pytree_map(obj, lambda x: type(x), map_all=True) + +def pytree_filter(fn, obj, process_types): + if obj is None: + return None + + filters = [] + + def condition_append(obj): + if fn(obj): + filters.append(obj) + + pytree_map(obj, fn=condition_append, process_types=process_types) + return filters + + +def get_real_args_kwargs(args_or_kwargs): + args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future) + # TODO : combine producer and consumer + # by default, merge all args in the output args or kwargs + if args_or_kwargs is not None: + if isinstance(args_or_kwargs, dict): + pass + else: + flatten_args = [] + pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) + args_or_kwargs = flatten_args + + return args_or_kwargs + + +def run_worker(rank, args, master_func): + os.environ['MASTER_ADDR'] = args.master_addr + os.environ['MASTER_PORT'] = args.master_port + + device = args.device + world_size = args.world_size + dp_degree = args.dp_degree + tp_degree = args.tp_degree + num_worker_threads = args.num_worker_threads + host = args.master_addr + port = args.master_port + backend = 'nccl' if device == 'cuda' else 'gloo' + + launch(dict(), rank, world_size, host, int(port), backend, verbose=False) + ppg.set_global_info(rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device) + ppg.args = args + # in rpc mode, only rank 0 is needed to be coded + if rank == 0: + master_func(args) + # barrier here + if _is_current_rpc_agent_set(): + rpc.shutdown() + else: + warnings.warn("RPC has not been initialized") + + +def rpc_run(args, master_func): + world_size = args.world_size + mp.spawn(run_worker, args=(args, master_func), nprocs=world_size) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--epoch', type=int, default=1) + parser.add_argument('--world_size', type=int, default=2) + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--dp_degree', type=int, default=1) + parser.add_argument('--tp_degree', type=int, default=1) + parser.add_argument('--num_microbatches', type=int, default=2) + parser.add_argument('--chunk', type=int, default=1) + parser.add_argument('--use_checkpoint', action='store_true') + parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD') + parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') + parser.add_argument('--master_addr', type=str, default='localhost') + parser.add_argument('--master_port', type=str, default='29020') + parser.add_argument('--num_worker_threads', type=str, default=128) + return parser.parse_args() diff --git a/colossalai/pipeline/utils.py b/colossalai/pipeline/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..df7226644a7a9931ecd82c09a39653fc5e0cfdeb --- /dev/null +++ b/colossalai/pipeline/utils.py @@ -0,0 +1,275 @@ +import heapq +import inspect +import torch + +from colossalai.logging import get_dist_logger +from colossalai.nn.layer.utils import CheckpointModule +from typing import List + +from collections import OrderedDict + +def _binary_partition(weights: List, start: int, end: int): + """Returns the binary partition position of `weights`, given the start + position `st` and the end position `ed`. + + Args: + weights (list): A python list to be binary partitioned + start (int): the start position of the binary partition + end (int): the end position of the binary partition + + Returns: + int: the binary partition position of `weights` + """ + w_sum = weights[end - 1] + prefix = 0 + if start > 0: + w_sum -= weights[start - 1] + prefix = weights[start - 1] + minimum = float("inf") + for idx in range(start + 1, end): + front = weights[idx - 1] - prefix + diff = abs(w_sum - 2 * front) + if diff < minimum: + pos = idx + minimum = diff + + return start, pos, end + + +def _heap_addition(weights: List, intervals: int, add_cnt: int): + """ + """ + + def _heap_push(heap, st, ed): + value = weights[ed - 1] + if st > 0: + value -= weights[st - 1] + heapq.heappush(heap, (-value, st, ed)) + + ret_intervals = [] + heap = [] + + for st, ed in intervals: + _heap_push(heap, st, ed) + + while add_cnt > 0: + _, st, ed = heapq.heappop(heap) + if ed - st == 1: + ret_intervals.append((st, ed)) + else: + l, m, r = _binary_partition(weights, st, ed) + _heap_push(heap, l, m) + _heap_push(heap, m, r) + add_cnt -= 1 + + while heap: + _, st, ed = heapq.heappop(heap) + ret_intervals.append((st, ed)) + + ret_intervals.sort() + return ret_intervals + + +def _calc_partitions(weights, value): + prev = 0 + prefix = 0 + num_block = 0 + intervals = [] + + for idx, w in enumerate(weights): + if weights[idx] - prefix > value: + intervals.append((prev, idx)) + prev = idx + prefix = weights[idx - 1] + num_block += 1 + + intervals.append((prev, len(weights))) + return num_block + 1, intervals + + +def _binary_search(weights, num): + length = len(weights) + prefix = [1 if w == 0 else w for w in weights] + for i in range(1, length): + prefix[i] += prefix[i - 1] + + lower_bound = max(weights) + upper_bound = prefix[length - 1] + + while upper_bound > lower_bound: + mid = (upper_bound + lower_bound) // 2 + number, _ = _calc_partitions(prefix, mid) + if number <= num: + upper_bound = mid + else: + lower_bound = mid + 1 + + num_block, intervals = _calc_partitions(prefix, upper_bound) + if num_block < num: + intervals = _heap_addition(prefix, intervals, num - num_block) + + return intervals + + +def partition_uniform(num_items, pipeline_parallel_size, num_chunks): + assert num_items % num_chunks == 0, \ + "Layer length should be divided by the number of chunks, otherwise parameter method is recomended" + + logger = get_dist_logger() + parts = [[] for _ in range(pipeline_parallel_size)] + partition_items = num_items // num_chunks + for idx in range(num_chunks): + base_idx = idx * partition_items + chunk_size = partition_items // pipeline_parallel_size + left = pipeline_parallel_size - partition_items % pipeline_parallel_size + if chunk_size == 0: + logger.warning("Some nodes in Pipeline have no requests") + + for p in range(pipeline_parallel_size): + st = base_idx + base_idx += chunk_size + (p >= left) + parts[p].append((st, base_idx)) + + return parts + + +def partition_balanced(weights, pipeline_parallel_size, num_chunks): + num_total = pipeline_parallel_size * num_chunks + num_items = len(weights) + if num_items <= num_total: + return partition_uniform(num_items, pipeline_parallel_size, num_chunks) + + intervals = _binary_search(weights, num_total) + + current = 0 + parts = [[] for _ in range(pipeline_parallel_size)] + for inter in intervals: + parts[current].append(inter) + current = (current + 1) % pipeline_parallel_size + + return parts + + +def build_kwargs_for_module(function, input_tensor, kw_dict): + """ + Generally, the first argument of module.forward is an input tensor come from the previous layer. + Therefore, we just filter the kwargs from second element of the dictionary. + """ + sig = inspect.signature(function) + if input_tensor is None: + kwargs_offset = 0 + elif isinstance(input_tensor, torch.Tensor): + kwargs_offset = 1 + elif isinstance(input_tensor, (tuple, OrderedDict)): + #assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' + # Huggingface will take their own structures based on OrderedDict as the output + # between layers so we've to close this check. + kwargs_offset = len(input_tensor) + args_name_list = list(sig.parameters.keys()) + kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]} + if len(kw_dict) == 0: + return None + return kw_dict + + +def build_kwargs_for_function(function, kw_dict): + sig = inspect.signature(function) + kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters} + if len(kw_dict) == 0: + return None + return kw_dict + + +def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs): + """ + We suppose the callable object passed to to_layer_list method in two purpose: + a. use the callable object to modify input tensor, such as \ + lambda x: torch.flatten(x, 1) + b. use the callable object to modify kwargs value, such as \ + def foo(attention_mask=None): + if attention_mask is not None: + batch_size = input_ids.shape[0] + attention_mask = attention_mask.view(batch_size, -1) + return attention_mask + """ + + if kw_dict is not None: + rst = func(**kw_dict) + if isinstance(rst, tuple): + for i, k in enumerate(kw_dict.keys()): + kwargs[k] = rst[i] + else: + for k in kw_dict.keys(): + kwargs[k] = rst + return input_tensor + if isinstance(input_tensor, tuple): + assert len(input_tensor) > 0, f'input_tensor should not be empty, when kw_dict is None.' + sig = inspect.signature(func) + func_args_num = len(sig.parameters) + assert func_args_num <= len( + input_tensor), f'func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}.' + if func_args_num < len(input_tensor): + return func(*input_tensor[:func_args_num]) + else: + return func(*input_tensor) + assert isinstance(input_tensor, torch.Tensor), 'input_tensor should be a type of torch.Tensor or tuple.' + return func(input_tensor) + + +def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs): + + assert func_key in func_dict, f"{func_key} is not in the function_dict." + funcs_to_exec = func_dict[func_key] + if isinstance(funcs_to_exec, list): + for f in funcs_to_exec: + f_kwargs = build_kwargs_for_function(f, kwargs) + input_tensor = exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs) + else: + f_kwargs = build_kwargs_for_function(funcs_to_exec, kwargs) + input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs) + + return input_tensor + + +def call_module(module, args=None, kwargs=None): + if args is None: + args = () + if kwargs is None: + kwargs = {} + if isinstance(module, CheckpointModule): + forward_func = module._forward + else: + forward_func = module.forward + sig = inspect.signature(forward_func) + param_nums = len(sig.parameters) + feed_nums = len(args) + len(kwargs) + args_needed_nums = param_nums - len(kwargs) + args_needed = args[:args_needed_nums] + if isinstance(module, CheckpointModule): + convert_kwargs_to_args = [] + for v in kwargs.values(): + convert_kwargs_to_args.append(v) + return module(*args_needed, *convert_kwargs_to_args) + else: + return module(*args_needed, **kwargs) + + +def customized_partition(exec_seq): + ''' + This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an + annotation to note the partition point. + ''' + customized_parts = {} + start = 0 + stop = 0 + rank = 0 + for element in exec_seq: + if isinstance(element, str): + if element == 'SPLIT_NODE': + customized_parts[rank] = [(start, stop)] + start = stop + rank += 1 + else: + stop += 1 + customized_parts[rank] = [(start, stop)] + return customized_parts diff --git a/colossalai/registry/__init__.py b/colossalai/registry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b48d9e7d21d2eed6f9e9dc54e26191aa3632a52c --- /dev/null +++ b/colossalai/registry/__init__.py @@ -0,0 +1,19 @@ +import torch.distributed.optim as dist_optim +import torch.nn as nn +import torch.optim as optim + +from .registry import Registry + +LAYERS = Registry("layers", third_party_library=[nn]) +MODELS = Registry("models") +OPTIMIZERS = Registry("optimizers", third_party_library=[optim, dist_optim]) +DATASETS = Registry("datasets") +DIST_GROUP_INITIALIZER = Registry("dist_group_initializer") +GRADIENT_HANDLER = Registry("gradient_handler") +LOSSES = Registry("losses", third_party_library=[nn]) +HOOKS = Registry("hooks") +TRANSFORMS = Registry("transforms") +DATA_SAMPLERS = Registry("data_samplers") +LR_SCHEDULERS = Registry("lr_schedulers") +SCHEDULE = Registry("schedules") +OPHOOKS = Registry("ophooks") diff --git a/colossalai/registry/registry.py b/colossalai/registry/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..8a4173f7ab992079d180322245d25cbb9010b07c --- /dev/null +++ b/colossalai/registry/registry.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from types import ModuleType +from typing import List + + +class Registry: + """This is a registry class used to register classes and modules so that a universal + object builder can be enabled. + + Args: + name (str): The name of the registry . + third_party_library (list, optional): + List of third party libraries which are used in the initialization of the register module. + """ + + def __init__(self, name: str, third_party_library: List[ModuleType] = None): + self._name = name + self._registry = dict() + self._third_party_lib = third_party_library + + @property + def name(self): + return self._name + + def register_module(self, module_class): + """Registers a module represented in `module_class`. + + Args: + module_class (class): The module to be registered. + Returns: + class: The module to be registered, so as to use it normally if via importing. + Raises: + AssertionError: Raises an AssertionError if the module has already been registered before. + """ + module_name = module_class.__name__ + assert module_name not in self._registry, f"{module_name} not found in {self.name}" + self._registry[module_name] = module_class + + # return so as to use it normally if via importing + return module_class + + def get_module(self, module_name: str): + """Retrieves a module with name `module_name` and returns the module if it has + already been registered before. + + Args: + module_name (str): The name of the module to be retrieved. + Returns: + :class:`object`: The retrieved module or None. + Raises: + NameError: Raises a NameError if the module to be retrieved has neither been + registered directly nor as third party modules before. + """ + if module_name in self._registry: + return self._registry[module_name] + elif self._third_party_lib is not None: + for lib in self._third_party_lib: + if hasattr(lib, module_name): + return getattr(lib, module_name) + raise NameError(f'Module {module_name} not found in the registry {self.name}') + + def has(self, module_name: str): + """Searches for a module with name `module_name` and returns a boolean value indicating + whether the module has been registered directly or as third party modules before. + + Args: + module_name (str): The name of the module to be searched for. + Returns: + bool: A boolean value indicating whether the module has been registered directly or + as third party modules before. + """ + found_flag = module_name in self._registry + + if self._third_party_lib: + for lib in self._third_party_lib: + if hasattr(lib, module_name): + found_flag = True + break + + return found_flag diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2da64e6c33a0f410dd5b7a9e05ff5775cc0a6eb --- /dev/null +++ b/colossalai/tensor/__init__.py @@ -0,0 +1,18 @@ +from . import distspec +from .colo_parameter import ColoParameter +from .colo_tensor import ColoTensor +from .comm_spec import CollectiveCommPattern, CommSpec +from .compute_spec import ComputePattern, ComputeSpec +from .dist_spec_mgr import DistSpecManager +from .distspec import ReplicaSpec, ShardSpec +from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager +from .process_group import ProcessGroup +from .tensor_spec import ColoTensorSpec +from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor + +__all__ = [ + 'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter', + 'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', + 'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', + 'merge_same_dim_mesh_list' +] diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..3e4c8ce69df6c8a211847078d3f2c15206e8ed50 --- /dev/null +++ b/colossalai/tensor/colo_parameter.py @@ -0,0 +1,102 @@ +from typing import Optional + +import torch + +from colossalai.tensor.colo_tensor import ColoTensor +from colossalai.tensor.const import TensorType +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.tensor.tensor_spec import ColoTensorSpec + + +def filter_args(func, *args): + return [arg for arg in args if func(arg)] + + +def replace_args(args, kwargs, new_args): + args = new_args[:len(args)] + for k, v in zip(kwargs.keys(), new_args[len(args):]): + kwargs[k] = v + return tuple(args), kwargs + + +class ColoParameter(ColoTensor, torch.nn.Parameter): + r"""A kind of ColoTensor to be considered as a module parameter. + + """ + + def __new__(cls, + data: Optional[torch.Tensor] = None, + requires_grad: bool = True, + spec: ColoTensorSpec = None) -> 'ColoParameter': + if data is None: + data = torch.empty(0) + return torch.Tensor._make_subclass(cls, data, requires_grad) + + def __init__(self, + data: Optional[torch.Tensor] = None, + requires_grad: bool = True, + spec: ColoTensorSpec = None) -> None: + ColoTensor.__init__(self, data, spec) + self._type = TensorType.MODEL + # a list contains modules sharing this ColoParameter with others. + self._shared_param_modules = [] + + @property + def shared_param_modules(self): + return self._shared_param_modules + + @staticmethod + def from_torch_tensor(tensor: torch.Tensor, + requires_grad: bool = True, + spec: ColoTensorSpec = None) -> 'ColoParameter': + tensor = tensor.as_subclass(ColoParameter) + tensor.__init__(tensor, requires_grad=requires_grad, spec=spec) + return tensor + + def __repr__(self): + return f'ColoParameter: {ColoTensor.__repr__(self)}' + + @classmethod + def __torch_function__(cls, func, types, args=..., kwargs=None): + if ColoParamOpHookManager.has_hook(): + if not func.__name__.startswith('__'): + if kwargs is None: + kwargs = {} + params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values()) + if len(params) > 0: + with torch._C.DisableTorchFunction(): + new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values()) + args, kwargs = replace_args(args, kwargs, new_args) + ret = super().__torch_function__(func, types, args, kwargs) + with torch._C.DisableTorchFunction(): + ret = ColoParamOpHookManager.post_op(params, ret) + return ret + return super().__torch_function__(func, types, args, kwargs) + + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + else: + with torch._C.DisableTorchFunction(): + data = self.data.clone() + tensor = ColoParameter(data, + self.requires_grad, + spec=ColoTensorSpec(self.get_process_group(), self.dist_spec, self.compute_spec)) + memo[id(self)] = tensor + return tensor + + def __reduce_ex__(self, proto): + # Adapted from torch._utils._rebuild_parameter + # def _rebuild_colo_parameter(data, requires_grad, backward_hooks): + # colo_param = ColoParameter(data, requires_grad) + # colo_param._backward_hooks = backward_hooks + # return colo_param + + # return ( + # _rebuild_colo_parameter, + # (self.data, self.requires_grad, OrderedDict()) + # ) + + # TODO(jzy) we don't support object reflection now. + # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`. + raise NotImplementedError diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..7ecb407b5ef00f60645b9f0c5cf884b974ef712b --- /dev/null +++ b/colossalai/tensor/colo_tensor.py @@ -0,0 +1,322 @@ +from copy import copy +from functools import lru_cache +from typing import Callable, Optional, Set + +import torch + +from colossalai.tensor.dist_spec_mgr import DistSpecManager +from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec +from colossalai.tensor.process_group import ProcessGroup +from colossalai.tensor.tensor_spec import ColoTensorSpec + +from .const import TensorType +from .op_wrapper import _COLOSSAL_OPS + + +@lru_cache(None) +def _get_my_nowrap_functions() -> Set[Callable]: + Tensor = torch.Tensor + return { + Tensor._base.__get__, + Tensor.grad.__get__, + Tensor._grad.__get__, + Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor + } + + +def _convert_output(output, colo_spec: ColoTensorSpec): + if type(output) == torch.Tensor: + return ColoTensor.from_torch_tensor(output, colo_spec) + elif isinstance(output, (list, tuple)): + return type(output)(_convert_output(o, colo_spec) for o in output) + else: + return output + + +def _get_spec_from_args(args, kwargs) -> ColoTensorSpec: + for elem in args: + if isinstance(elem, ColoTensor): + pg = elem.get_process_group() + dp = elem.dist_spec + return ColoTensorSpec(pg, dp) + elif isinstance(elem, (list, tuple)): + spec = _get_spec_from_args(elem, {}) + if spec is not None: + return spec + for k, v in kwargs.items(): + if isinstance(v, ColoTensor): + pg = v.get_process_group() + dp = v.dist_spec + return ColoTensorSpec(pg, dp) + return None + + +class ColoTensor(torch.Tensor): + """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor. + + The Colotensor can be initialized with a PyTorch tensor in the following ways. + + >>> pg = ProcessGroup() + >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()) + >>> # The tensor passed in is a tensor after sharding but not a global tensor. + >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size), + >>> dims=[0], + >>> num_partitions=[world_size]) + >>> tensor_spec = ColoTensorSpec(pg, shard_spec) + >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec) + + Args: + data (torch.Tensor): a torch tensor used as the payload the colotensor. + spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()). + """ + torch_minor = int(torch.__version__.split('.')[1]) + + def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor': + """ + The signature of the __new__ has to be consistent with the torch.Tensor. + + Args: + data (torch.Tensor): a torch tensor used as the payload the colotensor. + spec (TensorSpec, optional): the tensor spec of initialization. + + Returns: + ColoTensor: a ColoTensor wrappers the data. + """ + if data is None: + data = torch.empty(0) + return torch.Tensor._make_subclass(cls, data, data.requires_grad) + + def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> None: + # If not set spec, use a DP process group and replicate dist spec + if spec is None: + self.has_initialized = False + self.dist_spec = ReplicaSpec() + self.compute_spec = None + self.process_group = ProcessGroup() + else: + self.has_initialized = True + self.dist_spec = spec.dist_attr + self.compute_spec = spec.compute_attr + if spec.pg is None: + self.process_group = ProcessGroup() + else: + self.process_group = spec.pg + + self._type = TensorType.NONMODEL + + def has_compute_spec(self) -> bool: + return self.compute_spec is not None + + def is_model_data(self) -> bool: + return self._type == TensorType.MODEL + + def get_process_group(self) -> 'ProcessGroup': + return self.process_group + + def set_process_group(self, pg: ProcessGroup): + """set_process_group + change the pg of the ColoTensor. Note that the valid use cases is limited. + Only existing pg is DP and dist spec is REPLICaTE is valid. + + Args: + pg (ProcessGroup): target pg + + """ + assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid" + # if the new pg is the same as the old pg, just returns + if self.process_group == pg: + return + assert self.process_group.tp_world_size() == 1, \ + "Can not set_process_group on a ColoTensor whose process_group has tp world group" + assert self.dist_spec.placement.value == 'r', \ + "Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE" + + self.process_group = pg + + def get_tp_world_size(self) -> int: + return self.process_group.tp_world_size() + + def set_dist_spec(self, dist_spec: _DistSpec): + """set_dist_spec + set dist spec and change the payloads. + + Args: + dist_spec (_DistSpec): target dist spec. + """ + assert isinstance(dist_spec, _DistSpec) + assert self.process_group is not None + self._redistribute(dist_spec) + + def set_tensor_spec(self, dist_spec, compute_spec): + if dist_spec is not None: + assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}" + self.set_dist_spec(dist_spec) + if compute_spec is not None: + self.compute_spec = compute_spec + + def has_compute_pattern(self, compute_pattern): + return self.compute_spec.compute_pattern == compute_pattern + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if not all(issubclass(cls, t) for t in types): + return NotImplemented + global _COLOSSAL_OPS + if func in _COLOSSAL_OPS: + func = _COLOSSAL_OPS[func] + + if cls.torch_minor >= 12: + # in order to trigger pre-op hook in the forward of checkpoint module + # we have to capture the `backward` function + # and make sure that it does not in `torch._C.DisableTorchFunction()` context + if func is torch.Tensor.backward: + assert len(args) == 1 # only has 1 paramter + backward_tensor = torch.Tensor(args[0]) + tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()} + return backward_tensor.backward(**tensor_kwargs) + + with torch._C.DisableTorchFunction(): + ret = func(*args, **kwargs) + if func in _get_my_nowrap_functions(): + return ret + else: + colo_spec = _get_spec_from_args(args, kwargs) + return _convert_output(ret, colo_spec) + + def __repr__(self): + return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}' + + def _redistribute(self, dist_spec: _DistSpec) -> None: + """_redistribute + Note the function will not handle the logic of backward propagation! + It is used during model tensor initializations as an internal function. + + Args: + dist_spec (_DistSpec): the target dist. spec. + """ + assert self.grad_fn is None, "Current tensor has grad_fn and it can't get converted" + with DistSpecManager.no_grad(): + self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group) + self.dist_spec = dist_spec + + def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor': + """redistribute + Redistribute the tensor among processes. The rule is like this: + + 1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the + DP process group not changed. + + 2. If the pg is not not None and not equal to the current process group. + First, convert the tensor as replicated among the TP process group. + Second, reset the process group to the new pg. + Third, conver the tensor (new replicated both among the tp process group) to the new dist_spec. + + Args: + dist_spec (_DistSpec): the new dist spec. + pg (Optional[ProcessGroup], optional): the new process group . Defaults to None. + + Returns: + ColoTensor: a redistributed colotensor + """ + if pg is not None and pg != self.get_process_group(): + # if the pg is not equal, convert the current tensor to replicated + handled = self.redistribute(ReplicaSpec()) + else: + handled = self + pg = self.process_group + + ret = DistSpecManager.handle_trans_spec(handled, handled.dist_spec, dist_spec, pg) + return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec)) + + def to_replicate_(self): + """to_replicate_ + + an inline member function, converting dist spec of the tensor to REPLICATE + """ + self._redistribute(dist_spec=ReplicaSpec()) + + def to_replicate(self) -> 'ColoTensor': + """to_replicate + + converting dist spec of the tensor to ReplicaSpec() + """ + return self.redistribute(ReplicaSpec()) + + @staticmethod + def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor': + """from_torch_tensor + + A static method builds a `ColoTensor` from a PyTorch Tensor. + + Args: + tensor (torch.Tensor): the pytorch tensor, which is a local tensor for this rank not a global tensor. + spec (Optional[ColoTensorSpec], optional): tensor spec. Defaults to None. + + Returns: + ColoTensor: a ColoTensor + """ + tensor = tensor.as_subclass(ColoTensor) + tensor.__init__(tensor, spec=spec) + return tensor + + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + else: + with torch._C.DisableTorchFunction(): + data = self.data.clone() + tensor = ColoTensor(data, spec=copy(ColoTensorSpec(self.process_group, self.dist_spec, self.compute_spec))) + memo[id(self)] = tensor + return tensor + + # override builtin functions which must use tensor in replicate placement # + + def size_local(self, *args) -> torch.Size: + with torch._C.DisableTorchFunction(): + return super().size(*args) + + def size_global(self, *args) -> torch.Size: + """size_global + + override the torch buildin size() + the shape passed in must be in a replicate placement. + + Returns: + torch.Size: the global tensor shape + """ + if self.is_replicate(): + return self.size_local(*args) + spec = self.dist_spec + dims = spec.dims + num_partitions = spec.num_partitions + # import inspect + # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()]) + size_list = list(self.size_local()) + for dim, num_partition in zip(dims, num_partitions): + size_list[dim] *= num_partition + if args == (): + return torch.Size(size_list) + else: + return size_list[args[0]] + + # Some API for dist spec check + + def is_replicate(self): + return self.dist_spec.placement == DistPlacementPattern.REPLICATE \ + or (len(self.dist_spec.num_partitions) == 1 + and self.dist_spec.num_partitions[0] == 1) \ + or (self.process_group.tp_world_size() == 1) + + def is_shard_1dcol(self): + return self.dist_spec.placement == DistPlacementPattern.SHARD \ + and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1 + + def is_shard_1drow(self): + return self.dist_spec.placement == DistPlacementPattern.SHARD \ + and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0 + + def is_sharded(self): + return self.dist_spec.placement == DistPlacementPattern.SHARD diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..3c9e0fd566967e90f1ceda44acb11a209a9b783c --- /dev/null +++ b/colossalai/tensor/comm_spec.py @@ -0,0 +1,525 @@ +import operator +from enum import Enum +from functools import reduce + +import torch +import torch.distributed as dist +from torch.distributed import ReduceOp + +__all__ = [ + 'CollectiveCommPattern', + 'CommSpec', +] + + +def _all_gather(tensor, comm_spec): + ''' + Implement all gather operation on device mesh based on information provided by comm_spec. + ''' + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + tensor_list = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) + for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) + ] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output + + +def _split(tensor, comm_spec): + ''' + Implement shard operation on device mesh based on information provided by comm_spec. + ''' + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, _ in process_groups_list: + if dist.get_rank() in rank_list: + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + start = length * rank_list.index(dist.get_rank()) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output + + +def _all_to_all(tensor, comm_spec): + ''' + Implement all to all operation on device mesh based on information provided by comm_spec. + ''' + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) + new_shape = torch.Size(new_shape) + output_tensor_list = [ + torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) + ] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + input_tensor_list = [ + torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) + ] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output + + +def _all_reduce(tensor, comm_spec, async_op=False): + ''' + Implement all reduce operation on device mesh based on information provided by comm_spec. + ''' + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor + + +def _mix_gather(tensor, comm_spec): + ''' + Implement mix gather operation on device mesh based on information provided by comm_spec. + Mix gather is the all-gather operation on all devices in the device_mesh(FlattenDeviceMesh) of the comm_spec. It is + different from _all_gather because _mix_gather does all-gather in two dimensions of device mesh, while _all_gather + only does all-gather in one dimension. + Assume index of f and b target pairs are 'f' and 'b' + ShardingSpec => gather_dim, logical_process_axes + S0S1 => [b, f], (1, 0) + S1S0 => [b, f], (0, 1) + S01R => [f], (1, 1) + RS01 => [b], (1, 1) + Example: + mesh_shape = (2,4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} + S0S1: + leading_group_dim = 1 + process_group = "[0, 1, 2, 3, 4, 5, 6, 7]" + tensor_list = [(0,0),(0,1),(0,2),(0,3),(1,0),(1,1),(1,2),(1,3)] # [(slice_id_f, slice_id_b),...] + mesh_shape = (2,4) + cat_slice = [4,2] + tmp_tensor_list = [(...,shape[f],shape[b]*4,...),(...,shape[f],shape[b]*4,...)] + tmp_tensor_list[0] = torch.cat(((0,0),(0,1),(0,2),(0,3)), dim=b) + tmp_tensor_list[1] = torch.cat(((1,0),(1,1),(1,2),(1,3)), dim=b) + output = torch.cat((tmp_tensor_list[0],tmp_tensor_list[1]), dim=a) + S1S0: + leading_group_dim = 0 + process_group = "[0, 4, 1, 5, 2, 6, 3, 7]" + tensor_list = [(0,0),(0,1),(1,0),(1,1),(2,0),(2,1),(3,0),(3,1)] + mesh_shape = (2,4) + cat_slice = [2,4] + tmp_tensor_list = [(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...)] + tmp_tensor_list[0] = torch.cat(((0,0),(0,1)), dim=b) + tmp_tensor_list[1] = torch.cat(((1,0),(1,1)), dim=b) + tmp_tensor_list[2] = torch.cat(((2,0),(2,1)), dim=b) + tmp_tensor_list[3] = torch.cat(((3,0),(3,1)), dim=b) + S10R: + leading_group_dim = 0 + process_group = "[0, 4, 1, 5, 2, 6, 3, 7]" + tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)] + S01R: + leading_group_dim = 1 + process_group = "[0, 1, 2, 3, 4, 5, 6, 7]" + tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)] + ''' + total_slices = comm_spec.device_mesh.mesh_shape[0] + tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)] + leading_group_dim = comm_spec.logical_process_axes[0] + assert len(comm_spec.device_mesh.process_groups_dict) == 1 + _, process_group = comm_spec.device_mesh.process_groups_dict[0][0] + process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim] + + # Global all_gather + dist.all_gather(tensor_list, tensor, group=process_group) + + # This is very ugly. I'm figuring out more elegant methods + tensor_list_sorted = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices) + ] + for i in range(total_slices): + tensor_list_sorted[i] = tensor_list[process_number_list[i]] + tensor_list = tensor_list_sorted + + if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]: + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous() + else: + mesh_shape = comm_spec.device_meshes.mesh_shape + cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]] + tmp_tensor_shape = list(tensor.shape) + tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0] + tmp_tensor_shape = torch.Size(tmp_tensor_shape) + tmp_tensor_list = [ + torch.zeros(tmp_tensor_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(cat_slice[1]) + ] + for i in range(cat_slice[1]): + tmp_tensor_list[i] = torch.cat(tuple(tensor_list[i * cat_slice[0]:(i + 1) * cat_slice[0]]), + comm_spec.gather_dim[0]).contiguous() + output = torch.cat(tuple(tmp_tensor_list), comm_spec.gather_dim[1]).contiguous() + + return output + + +def _mix_split(tensor, comm_spec): + ''' + Implement mix split operation. Mix split is only called for the backward of mix gather (Use ctx to keep consistent) + Mix split shards the tensor on device mesh based on information provided by comm_spec. It is different from split + because _mix_split shards the tensor in two dimensions of device mesh, while _split only shards in one dimension. + Assume index of f and b target pairs are 'f' and 'b' + S0S1 => [b, f], (1, 0) + S1S0 => [b, f], (0, 1) + S01R => [f], (0, 0) + RS01 => [b], (0, 0) + Example: + mesh_shape = (2,4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} + ''' + mesh_shape = comm_spec.device_meshes.mesh_shape + dim = comm_spec.gather_dim + total_slices = comm_spec.device_mesh.mesh_shape[0] + + # Get global rank + rank = dist.get_rank() + + leading_group_dim = comm_spec.logical_process_axes[0] + process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim] + rank = process_number_list.index(rank) + + if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]: + length = tensor.shape[dim[0]] // total_slices + start = length * rank + output = torch.narrow(tensor, dim[0], start, length).contiguous() + else: + tensor_shape = [tensor.shape[dim[0]], tensor.shape[dim[1]]] + rank_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]] + length = [tensor_shape[0] // rank_slice[0], tensor_shape[1] // rank_slice[1]] + start = [(rank % rank_slice[0]) * length[0], (rank // rank_slice[0]) * length[1]] + tmp_output = torch.narrow(tensor, dim[0], start[0], length[0]).contiguous() + output = torch.narrow(tmp_output, dim[1], start[1], length[1]).contiguous() + + return output + + +class _ReduceGrad(torch.autograd.Function): + """ + A customized communication operation which forward is an identity operation, + backward is all_reduce operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _all_reduce(grad_output, ctx.comm_spec), None + + +class _ReduceInput(torch.autograd.Function): + """ + A customized communication operation which forward is all_reduce operation, + backward is an identity operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return _all_reduce(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + return _all_reduce(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + A customized communication operation which forward is split operation, + backward is an all gather operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return _split(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return _split(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return _all_gather(grad_output, ctx.comm_spec), None + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """ + A customized communication operation which forward is an all gather operation, + backward is split operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return _all_gather(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return _all_gather(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.comm_spec), None + + +class _AllToAll(torch.autograd.Function): + """ + A customized communication operation which forward is an all to all operation, + backward is an all to all operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return _all_to_all(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + output = _all_to_all(input_, comm_spec) + comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, + sharding_spec=comm_spec.sharding_spec, + gather_dim=comm_spec.shard_dim, + shard_dim=comm_spec.gather_dim, + logical_process_axis=comm_spec.logical_process_axis) + ctx.comm_spec = comm_spec_for_backward + return output + + @staticmethod + def backward(ctx, grad_outputs): + return _all_to_all(grad_outputs, ctx.comm_spec), None + + +class _MixGatherForwardMixSplitBackward(torch.autograd.Function): + + @staticmethod + def symbolic(graph, input_): + return _mix_gather(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return _mix_gather(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return _mix_split(grad_output, ctx.comm_spec), None + + +def reduce_grad(input_, comm_spec): + return _ReduceGrad.apply(input_, comm_spec) + + +def reduce_input(input_, comm_spec): + return _ReduceInput.apply(input_, comm_spec) + + +def split_forward_gather_backward(input_, comm_spec): + return _SplitForwardGatherBackward.apply(input_, comm_spec) + + +def gather_forward_split_backward(input_, comm_spec): + return _GatherForwardSplitBackward.apply(input_, comm_spec) + + +def all_to_all(input_, comm_spec): + return _AllToAll.apply(input_, comm_spec) + + +def mixgather_forward_split_backward(input_, comm_spec): + return _MixGatherForwardMixSplitBackward.apply(input_, comm_spec) + + +class CollectiveCommPattern(Enum): + GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd' + ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd' + SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd' + ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd' + IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd' + MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd" + + +class CommSpec: + ''' + Communication spec is used to record the communication action. It has two main functions: + 1. Compute the communication cost which will be used in auto parallel solver. + 2. Convert the communication spec to real action which will be used in runtime. + It contains comm_pattern to determine the + communication method, sharding_spec to determine the communication size, gather_dim and shard_dim + to determine the buffer shape, and logical_process_axis + + Argument: + comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec. + sharding_spec(ShardingSpec): This is sharding spec of the tensor which will join the communication action. + gather_dim(int, Optional): The gather_dim of the tensor will be gathered. + shard_dim(int, Optional): The shard_dim of the tensor will be sharded. + logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. + ''' + + def __init__(self, + comm_pattern, + sharding_spec, + gather_dim=None, + shard_dim=None, + logical_process_axis=None, + forward_only=False, + mix_gather=False): + self.comm_pattern = comm_pattern + self.sharding_spec = sharding_spec + self.gather_dim = gather_dim + self.shard_dim = shard_dim + self.logical_process_axis = logical_process_axis + self.forward_only = forward_only + if isinstance(self.logical_process_axis, list): + if not mix_gather: + self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + self.logical_process_axis = 0 + else: + self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes + self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + # Create a new member `logical_process_axes` to distinguish from original flatten + self.logical_process_axes = logical_process_axis + else: + self.device_mesh = self.sharding_spec.device_mesh + + def __repr__(self): + res_list = ["CommSpec:("] + if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: + res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ") + res_list.append(f"gather_dim:{self.gather_dim}, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: + res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ") + res_list.append(f"gather_dim:{self.gather_dim}, ") + res_list.append(f"shard_dim:{self.shard_dim}, ") + res_list.append(f"logical_process_axis: {self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: + res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ") + res_list.append(f"shard_dim:{self.shard_dim}, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: + res_list.append(f"comm_pattern:ALLREDUCE_FWD_IDENTITY_BWD, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: + res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: + res_list.append(f"comm_pattern:MIXGATHER_FWD_SPLIT_BWD, ") + res_list.append(f"gather_dim:{self.gather_dim}, ") + res_list.append(f"logical_process_asex:{self.logical_process_axes})") + + return ''.join(res_list) + + def get_comm_cost(self): + ''' + For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to + compute the communication cost. + For shard operation, it is an on-chip operation, so the communication cost is zero. + ''' + comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1) + cost_dict = {} + if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: + forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) + # give a tiny cost to shard + backward_communication_cost = 10 + + if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: + forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis) + # grad should have same shape as input tensor + # all to all operation has same logical process axis as forward. + backward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis) + + if self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: + forward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis) + backward_communication_cost = 0 + + if self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: + forward_communication_cost = 0 + backward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis) + + if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: + # give a tiny cost to shard + forward_communication_cost = 10 + backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) + + if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: + # no need for axis because all devices are used in mix_gather + forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size) + backward_communication_cost = 10 + + if self.forward_only: + cost_dict["forward"] = forward_communication_cost + cost_dict["backward"] = 0 + cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"] + else: + cost_dict["forward"] = forward_communication_cost + cost_dict["backward"] = backward_communication_cost + cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"] + + return cost_dict + + def covert_spec_to_action(self, tensor): + ''' + Convert CommSpec into runtime action, implement real collection communication to target tensor. + The collection communication action is directed by the CommSpec. + + Argument: + tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks. + ''' + if self.comm_pattern in pattern_to_func_dict: + tensor = pattern_to_func_dict[self.comm_pattern](tensor, self) + else: + tensor = tensor + return tensor + + +pattern_to_func_dict = { + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: gather_forward_split_backward, + CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: all_to_all, + CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward, + CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input, + CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad, + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: mixgather_forward_split_backward, +} diff --git a/colossalai/tensor/compute_spec.py b/colossalai/tensor/compute_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..a9774c34c01bfb1b536c88afdd5c7806a8859edb --- /dev/null +++ b/colossalai/tensor/compute_spec.py @@ -0,0 +1,29 @@ +from enum import Enum + + +class ComputePattern(Enum): + TP1D = 0 + TP2D = 1 + TP2P5D = 2 + TP3D = 3 + + +class ComputeSpec(object): + """ComputeSpec + The Specification for compuattion pattern + + Args: + compute_pattern (ComputePattern): an Enum instance for compute pattern. + """ + + def __init__(self, compute_pattern: ComputePattern) -> None: + assert isinstance(compute_pattern, ComputePattern) + self.compute_pattern = compute_pattern + # Make sure output tensors are replicate + self.output_replicate = True + + def __repr__(self): + return f'Compute pattern: {self.compute_pattern}' + + def set_output_replicate(self, flag: bool = True): + self.output_replicate = flag diff --git a/colossalai/tensor/const.py b/colossalai/tensor/const.py new file mode 100644 index 0000000000000000000000000000000000000000..356e8ecc885a3fb24766683b106a91ca2fac44eb --- /dev/null +++ b/colossalai/tensor/const.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class TensorType(Enum): + MODEL = 0 + NONMODEL = 1 # mainly activations diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py new file mode 100644 index 0000000000000000000000000000000000000000..d5c0ce28e9fb6d2c912a4530970d81bfd2fba7ae --- /dev/null +++ b/colossalai/tensor/dist_spec_mgr.py @@ -0,0 +1,189 @@ +from contextlib import contextmanager + +import torch +import torch.distributed as dist +# from colossalai.nn.layer.utils import divide +from numpy import prod +from packaging import version + +from colossalai.logging import get_dist_logger +from colossalai.tensor.distspec import _DistSpec +from colossalai.tensor.process_group import ProcessGroup + + +# TODO(jiaruifang) circle import, move the divide to colossalai.commons. +# colossalai.tensor shall not import any submodule from colossal.nn +def divide(numerator, denominator): + """Only allow exact division. + + Args: + numerator (int): Numerator of the division. + denominator (int): Denominator of the division. + + Returns: + int: the result of exact division. + """ + assert denominator != 0, 'denominator can not be zero' + assert numerator % denominator == 0, \ + '{} is not divisible by {}'.format(numerator, denominator) + return numerator // denominator + + +class TransformDistSpec(torch.autograd.Function): + + @staticmethod + def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backward_trans_func): + ctx.old_dist_spec = old_dist_spec + ctx.dist_spec = dist_spec + ctx.backward_trans_func = backward_trans_func + ctx.pg = pg + return forward_trans_func(tensor, old_dist_spec, dist_spec, pg) + + @staticmethod + def backward(ctx, grad_outputs): + return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec, + ctx.pg), None, None, None, None, None + + +class DistSpecManager: + + _use_autograd_function: bool = True + + @staticmethod + def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None: + pass + + @staticmethod + def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, + pg: ProcessGroup) -> torch.Tensor: + """_shard_as: shard the tensor w.r.t a distributed specification. + Assuming the tensor passed in is a global (replicated) tensor. + Args: + tensor (torch.Tensor): a global (replicated) tensor before shard + dist_spec (_DistSpec): the distributed spec. to be sharded as. + pg (ProcessGrouo): the process group of the corresponding colotensor + Returns: + torch.Tensor: a torch tensor after sharded. + """ + assert old_dist_spec.placement.value == 'r', f"The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!" + DistSpecManager._sanity_check(old_dist_spec, dist_spec) + + chunk = tensor + idx = pg.tp_local_rank() + num_parts = prod(dist_spec.num_partitions) + for i, dim in enumerate(dist_spec.dims): + num_parts //= dist_spec.num_partitions[i] + + chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i]) + chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size) + idx %= num_parts + return chunk.clone().detach().contiguous() + + @staticmethod + def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: + """_gather gather sharded tensors to a replicated one. + Args: + tensor (torch.Tensor): a shared torch tensor + old_dist_spec (_DistSpec): the distributed spec. of the tensor. + + Returns: + torch.Tensor: a replicated tensor. + """ + assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!" + is_cpu_tensor = False + if tensor.device.type == 'cpu': + # pytorch lower than 1.11 dose not support gather a cpu tensor. + # Therefore, we transfer tensor to GPU before gather. + saved_dev = tensor.device + tensor.data = tensor.data.cuda() + is_cpu_tensor = True + + buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())] + assert tensor.device.type == 'cuda' + dist.all_gather(buffer, tensor, group=pg.tp_process_group()) + for i in range(len(old_dist_spec.dims) - 1, -1, -1): + new_buffer = [] + dim = old_dist_spec.dims[i] + num_parts = old_dist_spec.num_partitions[i] + for start in range(0, len(buffer), num_parts): + new_buffer.append(torch.cat(buffer[start:start + num_parts], dim)) + buffer = new_buffer + assert len(buffer) == 1 + + if is_cpu_tensor: + buffer[0].data = buffer[0].data.to(saved_dev) + return buffer[0] + + @staticmethod + def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, + pg: ProcessGroup) -> torch.Tensor: + world_size = pg.tp_world_size() + if world_size == 1: + return tensor + + assert tensor.device.type == "cuda", \ + "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \ + f"collective function, however, we got {tensor.device.type} device" + + gather_dim = old_dist_spec.dims[0] + scatter_dim = dist_spec.dims[0] + shapes = list(tensor.shape) + scattered_dim_size = shapes[scatter_dim] // world_size + gathered_dim_size = shapes[gather_dim] * world_size + shapes[scatter_dim] = scattered_dim_size + + scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)] + gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + dist.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) + + output_ = torch.cat(gather_list, dim=gather_dim).contiguous() + assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size + return output_ + + @staticmethod + def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: + DistSpecManager._sanity_check(old_dist_spec, dist_spec) + return tensor + + @staticmethod + def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: + DistSpecManager._sanity_check(old_dist_spec, dist_spec) + return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg) + + @staticmethod + def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: + DistSpecManager._sanity_check(old_dist_spec, dist_spec) + return DistSpecManager._gather(tensor, old_dist_spec, pg) + + @staticmethod + def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: + DistSpecManager._sanity_check(old_dist_spec, dist_spec) + if old_dist_spec == dist_spec: + return tensor + if len(old_dist_spec.dims) == 1 and len(dist_spec.dims) == 1: + # use all-to-all to save memory + return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec, pg) + tensor = DistSpecManager._gather(tensor, old_dist_spec, pg) + return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg) + + @staticmethod + def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, + pg: ProcessGroup) -> torch.Tensor: + assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec" + assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec" + forward_trans_handle = getattr(DistSpecManager, f'_{old_dist_spec.placement.value}2{dist_spec.placement.value}') + if not DistSpecManager._use_autograd_function: + return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg) + backward_trans_handle = getattr(DistSpecManager, + f'_{dist_spec.placement.value}2{old_dist_spec.placement.value}') + return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle, + backward_trans_handle) + + @staticmethod + @contextmanager + def no_grad(): + try: + DistSpecManager._use_autograd_function = False + yield + finally: + DistSpecManager._use_autograd_function = True diff --git a/colossalai/tensor/distspec.py b/colossalai/tensor/distspec.py new file mode 100644 index 0000000000000000000000000000000000000000..0b62cbdda2c57eb0118932dc2276888d98e1aa15 --- /dev/null +++ b/colossalai/tensor/distspec.py @@ -0,0 +1,77 @@ +from enum import Enum +from typing import List + +__all__ = ['ReplicaSpec', 'ShardSpec'] + + +class DistPlacementPattern(Enum): + REPLICATE = 'r' + SHARD = 's' + + +class _DistSpec: + """_DistSpec + + A class indicates Distributed Specification. + The DistSpec is only works for the tensor parallel process groups. + Because the dist spec of data parallel process group can be automatically deduced. + This is an internal data structrue. + The API for users should be `ShardSpec` and `ReplicaSpec`. + + Args: + dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes. + The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard. + process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None. + """ + + def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info): + + self.placement = dist_placement_pattern + for k, v in meta_info.items(): + setattr(self, k, v) + + def __eq__(self, other: "_DistSpec") -> bool: + if dir(self) != dir(other): + return False + for attr in dir(self): + if not attr.startswith('__') and getattr(self, attr) != getattr(other, attr): + return False + return True + + def __repr__(self) -> str: + res_list = ["DistSpec:"] + for attr in dir(self): + if not attr.startswith('__'): + res_list.append(f'\n\t{attr}: {str(getattr(self, attr))}') + return ''.join(res_list) + + +def ReplicaSpec() -> _DistSpec: + """ReplicaSpec + + A distributed specification represents the tensor is replicated among the tensor parallel process group. + + Returns: + _DistSpec: an replicated dist spec instance. + """ + return _DistSpec(DistPlacementPattern.REPLICATE) + + +def ShardSpec(dims: List[int], num_partitions: List[int]) -> _DistSpec: + """ShardSpec + + A distributed specification represents the tensor is sharded among the tensor parallel process group. + + Note: + Currently, only shard on one dimension is valid. In another word, dims should be of size 1. + + Args: + dims (List[int]): a list of dimensions + num_partitions (List[int]): a list of partition number of each dimensions. + + Returns: + _DistSpec: an shard dist spec instance. + """ + assert isinstance(dims, list) and isinstance(num_partitions, list) + assert len(dims) == len(num_partitions) + return _DistSpec(DistPlacementPattern.SHARD, dims=tuple(dims), num_partitions=tuple(num_partitions)) diff --git a/colossalai/tensor/op_wrapper.py b/colossalai/tensor/op_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..1c00066f74655ef997a4e53e4fc1ce97d33f6434 --- /dev/null +++ b/colossalai/tensor/op_wrapper.py @@ -0,0 +1,53 @@ +from typing import ( + Callable, + Dict, +) +import functools + +# Custom sharded ops +_COLOSSAL_OPS: Dict[str, Callable] = {} + + +def _register_colo_op(op, func): + global _COLOSSAL_OPS + _COLOSSAL_OPS[op] = func + + +def colo_op_impl(func): + """ + Provides a way for users to write their own custom operator. This + can be used to override existing ColoTensor operators or write a new + one not supported by ColoTensor. If the operator in question is covered + by ``__torch_function__`` dispatch and has a ColoTensor as any of its + parameters, the function provided will be invoked for that operator. + + Example: + >>> @colo_op_impl(torch.nn.functional.linear) + >>> def my_custom_linear(types, args, kwargs, process_group): + >>> .... + >>> + >>> input = torch.rand(10, 32) + >>> weight = ColoTensor(torch.rand(32, 16)) + >>> bias = ColoTensor(torch.rand(16)) + >>> # This will call `my_custom_linear` instead of the default. + >>> torch.nn.functional.linear(input, weight, bias) + + The types, args and kwargs parameters are the same parameters that are + passed to ``__torch_function__`` dispatch API + (https://pytorch.org/docs/stable/notes/extending.html#extending-torch). + + Args: + func(Callable): Torch function for which we want to provide a sharded + implementation (ex: torch.nn.functional.linear) + """ + + def decorator_sharded_func(wrapped_func): + _register_colo_op(func, wrapped_func) + + @functools.wraps(wrapped_func) + def wrapper(*args, **kwargs): + return wrapped_func(*args, **kwargs) + + return wrapper + + return decorator_sharded_func diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..2320d98bc36fdbd7e69df77761e04940201061a3 --- /dev/null +++ b/colossalai/tensor/param_op_hook.py @@ -0,0 +1,150 @@ +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import Any, List, Tuple + +import torch + +from colossalai.tensor.colo_tensor import ColoTensor +from colossalai.tensor.tensor_spec import ColoTensorSpec + + +class ColoParamOpHook(ABC): + """ + Hook which is triggered by each operation when operands contain ColoParameter. + To customize it, you must inherit this abstract class, and implement ``pre_forward``, + ``post_forward``, ``pre_backward`` and ``post_backward``. + These four methods apply a list of ColoParameter as input args. + """ + + @abstractmethod + def pre_forward(self, params: List[torch.Tensor]) -> None: + pass + + @abstractmethod + def post_forward(self, params: List[torch.Tensor]) -> None: + pass + + @abstractmethod + def pre_backward(self, params: List[torch.Tensor]) -> None: + pass + + @abstractmethod + def post_backward(self, params: List[torch.Tensor]) -> None: + pass + + +class ColoParamOpHookManager: + """ + Manage your param op hooks. It only has static methods. + The only static method you should call is ``use_hooks(*hooks)``. + """ + hooks: Tuple[ColoParamOpHook, ...] = tuple() + + @staticmethod + @contextmanager + def use_hooks(*hooks: ColoParamOpHook): + """Change the param op hooks you use. Nested calling is allowed. + + Example: + >>> with ColoParamOpHookManager.use_hooks(*hooks): + >>> do_something() + >>> with ColoParamOpHookManager.use_hooks(): + >>> // clear hooks + >>> do_something() + """ + try: + old_param_op_hooks = ColoParamOpHookManager.hooks + ColoParamOpHookManager.hooks = hooks + yield + finally: + ColoParamOpHookManager.hooks = old_param_op_hooks + + @staticmethod + def _trigger_pre_forward(params: List[torch.Tensor]) -> None: + for hook in ColoParamOpHookManager.hooks: + hook.pre_forward(params) + + @staticmethod + def _trigger_post_forward(params: List[torch.Tensor]) -> None: + for hook in ColoParamOpHookManager.hooks: + hook.post_forward(params) + + @staticmethod + def _trigger_pre_backward(params: List[torch.Tensor]) -> None: + for hook in ColoParamOpHookManager.hooks: + hook.pre_backward(params) + + @staticmethod + def _trigger_post_backward(params: List[torch.Tensor]) -> None: + for hook in ColoParamOpHookManager.hooks: + hook.post_backward(params) + + @staticmethod + def pre_op(params: List[torch.Tensor], *args: Any) -> list: + ColoParamOpHookManager._trigger_pre_forward(params) + args_info = _get_colo_tensors_info(*args) + rets = PreFwdPostBwd.apply(params, *args) + return _update_colo_tensors(args_info, *rets) + + @staticmethod + def post_op(params: List[torch.Tensor], arg: Any) -> Any: + ColoParamOpHookManager._trigger_post_forward(params) + arg_info = _get_colo_tensors_info(arg) + ret = PostFwdPreBwd.apply(params, arg) + return _unpack_args(_update_colo_tensors(arg_info, ret)) + + @staticmethod + def has_hook() -> bool: + return len(ColoParamOpHookManager.hooks) > 0 + + +class PreFwdPostBwd(torch.autograd.Function): + + @staticmethod + def forward(ctx, params, *args): + ctx.params = params + return _unpack_args(args) + + @staticmethod + def backward(ctx, *grads): + ColoParamOpHookManager._trigger_post_backward(ctx.params) + return (None,) + grads + + +class PostFwdPreBwd(torch.autograd.Function): + + @staticmethod + def forward(ctx, params, args): + ctx.params = params + return args + + @staticmethod + def backward(ctx, *grads): + ColoParamOpHookManager._trigger_pre_backward(ctx.params) + return (None,) + grads + + +def _unpack_args(args): + if len(args) == 1: + return args[0] + return args + + +def _get_colo_tensors_info(*args) -> list: + info = [] + for arg in args: + if isinstance(arg, ColoTensor): + info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec))) + else: + info.append(None) + return info + + +def _update_colo_tensors(info, *args) -> list: + ret = [] + for t_info, arg in zip(info, args): + if t_info is not None: + t_cls, spec = t_info + arg = t_cls.from_torch_tensor(arg, spec=spec) + ret.append(arg) + return ret diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py new file mode 100644 index 0000000000000000000000000000000000000000..e7e565071e580c6fb3adc63ddb3e8304847a2983 --- /dev/null +++ b/colossalai/tensor/process_group.py @@ -0,0 +1,311 @@ +import torch +from typing import List, Optional +from colossalai.logging import get_dist_logger +from colossalai.context.singleton_meta import SingletonMeta + + +class PyTorchProcessGroupDict(metaclass=SingletonMeta): + + def __init__(self): + # distributed settings + self.dict = {} + + def get(self, rank_list: List[int], backend: str = 'nccl'): + """Reuse Pytorch ProcessGroup when such a group is initialized + """ + rank_tuple = tuple(rank_list) + # we need to convert the passed list to a tuple + # since List is unhashable + pg_key = (backend, rank_tuple) + + if pg_key not in self.dict: + + self.logger = get_dist_logger('ProcessGroup') + self.logger.info(f'NCCL initialize ProcessGroup on {rank_list}', ranks=[0]) + self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend) + return self.dict[pg_key] + + +PYTORCHPGDICT_ = PyTorchProcessGroupDict() + + +class ProcessGroup: + """ProcessGroup + Process Group indicates how processes are organized in groups for parallel execution using Tensor Parallelism and Data Parallelism. + + NOTE, the ProcessGroup must be used after `torch.distributed.initialize()` + + + Args: + rank: the global rank of the current process. + ranks: List[int], a list of rank id belongings to this process group. + backend: str, the backend of the process group. + tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. default None means 1. + dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. . default None means len(ranks). + """ + + def __init__(self, + rank: Optional[int] = None, + ranks: Optional[List[int]] = None, + tp_degree: Optional[int] = None, + dp_degree: Optional[int] = None) -> None: + if not torch.distributed.is_initialized(): + self.is_init = False + return + + assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized" + if rank is None: + self._rank = torch.distributed.get_rank() + else: + self._rank = rank + + if ranks is None: + self._rank_list = list(range(torch.distributed.get_world_size())) + else: + self._rank_list = ranks + self._rank_list.sort() # ensure that the list is in order + + self._world_size = len(self._rank_list) + + if dp_degree is None and tp_degree is None: + self._dp_degree = self._world_size + self._tp_degree = 1 + elif dp_degree and not tp_degree: + self._dp_degree = dp_degree + assert self._world_size % self._dp_degree == 0, f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None" + self._tp_degree = self._world_size // dp_degree + elif not dp_degree and tp_degree: + self._tp_degree = tp_degree + assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None" + self._dp_degree = self._world_size // tp_degree + else: + self._dp_degree = dp_degree + self._tp_degree = tp_degree + assert self._dp_degree * self._tp_degree == self._world_size, \ + f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" \ + f"and TP degree {self._tp_degree}" + + self._tp_rank_list = None + self._dp_rank_list = None + + for i in range(self._dp_degree): + i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)] + PYTORCHPGDICT_.get(i_tp_list, 'nccl') + if self._rank in i_tp_list: + self._tp_rank_list = i_tp_list + + for j in range(self._tp_degree): + j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)] + PYTORCHPGDICT_.get(j_dp_list, 'nccl') + if self._rank in j_dp_list: + self._dp_rank_list = j_dp_list + + self._has_cpu_groups = False + self.is_init = True + + def set_cpu_groups(self): + """set_cpu_groups + Initialize Pytorch process groups for cpu communications. + """ + if self.has_cpu_groups: + return + + for i in range(self._dp_degree): + i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)] + PYTORCHPGDICT_.get(i_tp_list, 'gloo') + + for j in range(self._tp_degree): + j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)] + PYTORCHPGDICT_.get(j_dp_list, 'gloo') + + self._has_cpu_groups = True + + @property + def has_cpu_groups(self) -> bool: + """has_cpu_groups + If cpu groups have been initailized. + + Returns: + bool: cpu process groups have been initialized or not. + """ + return self._has_cpu_groups + + def __repr__(self): + if self.is_init: + return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\ + format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list) + else: + return "ProcessGroup not initialized" + + def __eq__(self, obj: 'ProcessGroup') -> bool: + if not isinstance(obj, ProcessGroup): + return False + if self._rank != obj._rank: + return False + if self._rank_list != obj._rank_list: + return False + if self._tp_rank_list != obj._tp_rank_list: + return False + if self._dp_rank_list != obj._dp_rank_list: + return False + if self._tp_degree != obj._tp_degree: + return False + if self._dp_degree != obj._dp_degree: + return False + return True + + def rank(self) -> int: + """rank + + The current rank in the global process group. + + Returns: + int: the rank number + """ + return self._rank + + def ranks_in_group(self) -> List[int]: + """ranks_in_group + + a list of rank number in in the global process group. + + Returns: + List[int]: a list of rank number. + """ + return self._rank_list + + def world_size(self) -> int: + """world_size + + The world size of the global process group. + + Returns: + int: world size + """ + return self._world_size + + def tp_rank_list(self) -> List[int]: + """tp_rank_list + + the rank list in the TP process group containing the current rank. + + Returns: + List[int]: the list of rank number. + """ + return self._tp_rank_list + + def dp_rank_list(self) -> List[int]: + """dp_rank_list + + the rank list in the DP process group containing the current rank. + + Returns: + List[int]: the list of rank number. + """ + return self._dp_rank_list + + def tp_local_rank(self) -> int: + """tp_local_rank + + The local rank number in the current TP process group. + + Returns: + int: tp rank number. + """ + return self._rank % self._tp_degree + + def dp_local_rank(self) -> int: + """dp_local_rank + + The local rank number in the current DP process group. + + Returns: + int: dp rank number. + """ + return self._rank // self._tp_degree + + def dp_world_size(self) -> int: + """dp_world_size + + The world size of the current DP process group. + + Returns: + int: dp world size + """ + return len(self._dp_rank_list) + + def tp_world_size(self) -> int: + """tp_world_size + + The world size of the current TP process group. + + Returns: + int: tp world size + """ + return len(self._tp_rank_list) + + def dp_process_group(self): + """dp_process_group + + the pytorch DP process group containing the current rank. + + Returns: + `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group. + """ + return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl') + + def tp_process_group(self): + """tp_process_group + + the pytorch TP process group containing the current rank. + + Returns: + `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group. + """ + return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl') + + def cpu_dp_process_group(self): + """cpu_dp_process_group + + the pytorch CPU DP process group containing the current rank. + + assert failed if cpu process group is not initialized. + + Returns: + `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group. + """ + assert self._has_cpu_groups + return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo') + + def cpu_tp_process_group(self): + """cpu_tp_process_group + + the pytorch CPU TP process group containing the current rank. + + assert failed if cpu process group is not initialized. + + Returns: + `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group. + """ + assert self._has_cpu_groups + return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo') + + def get_ranks_in_dp(self) -> List[int]: + """get_ranks_in_dp + + ranks in current dp process group. + + Returns: + List[int]: a list of rank number. + """ + return self._dp_rank_list + + def get_ranks_in_tp(self): + """get_ranks_in_tp + + ranks in current tp process group. + + Returns: + List[int]: a list of rank number. + """ + return self._tp_rank_list diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py new file mode 100644 index 0000000000000000000000000000000000000000..d566e35158432fb89abe920b2b6c8b119cd6a881 --- /dev/null +++ b/colossalai/tensor/shape_consistency.py @@ -0,0 +1,583 @@ +import math +from copy import deepcopy +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import torch + +from colossalai.context.singleton_meta import SingletonMeta +from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException +from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator + +from .comm_spec import * + +__all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options'] + + +@dataclass +class ShapeConsistencyOptions: + """ + ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency. + """ + # TODO: shape consistency option is not implemented yet + pass + + +def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor: + shape_consistency_manager = ShapeConsistencyManager() + global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {}) + with torch.no_grad(): + global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec, + global_sharding_spec) + return global_tensor + + +def set_shape_consistency_options(options: ShapeConsistencyOptions): + """ + Configure the shape consistency manager via function call. + """ + manager = ShapeConsistencyManager() + manager.options = options + + +class ShapeConsistencyManager(metaclass=SingletonMeta): + + def __init__(self): + self._options = None + self._forward_only = False + self.total_communication_cost = 0 + self.total_transform_steps = 0 + self.cached_spec_pairs_transform_path = {} + + @property + def options(self): + return self._options + + @options.setter + def options(self, options_: ShapeConsistencyOptions): + assert isinstance(options_, ShapeConsistencyOptions) + self._options = options_ + + @property + def forward_only(self): + return self._forward_only + + @forward_only.setter + def forward_only(self, value): + assert isinstance(value, bool) + self._forward_only = value + + def get_all_all_gather_spec(self, source_spec: ShardingSpec, + orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: + ''' + Get all valid sharding specs from source_spec with single all-gather operation, and + accumulate commucation cost on origin cost which will finally be used in auto sharding solver. + For the all-gather operation, we just care about the S dimension. + + Argument: + source_spec(ShardingSpec): the ShardingSpec of the source_spec. + orig_cost(Dict[str, float]): the original communication cost before this operation. + + Return: + valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation. + + Example: + dim_partition_dict = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4) + sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) + shape_consistency_manager = ShapeConsistencyManager() + rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0}) + print(rst_dict) + + Output: + {DistSpec: + shard_sequence: R,S1,R + device_mesh_shape: (4, 4): 0, DistSpec: + shard_sequence: S0,R,R + device_mesh_shape: (4, 4): 0} + ''' + valid_spec_dict = {} + comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD + for target_pair in source_spec.dim_partition_dict.items(): + shard_list = all_gather_simulator(target_pair) + index = target_pair[0] + new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) + + # We won't add empty list into dim_partition_dict + # The key will be popped if the related shard_list is empty + if shard_list: + new_dim_partition_dict[index] = shard_list + else: + new_dim_partition_dict.pop(index) + + # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec + gather_dim = index + logical_process_axis = target_pair[1][-1] + comm_spec = CommSpec( + comm_pattern, + sharding_spec=source_spec, + gather_dim=gather_dim, + # shard_dim will be used during backward + shard_dim=gather_dim, + logical_process_axis=logical_process_axis, + forward_only=self.forward_only) + + # compute the communication cost with CommSpec + cost_dict = comm_spec.get_comm_cost() + + # generate new sharding spec + try: + new_sharding_spec = ShardingSpec(source_spec.device_mesh, + source_spec.entire_shape, + dim_partition_dict=new_dim_partition_dict) + for phase, cost in cost_dict.items(): + cost_dict[phase] = cost + orig_cost_dict[phase] + valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) + except ShardingSpecException: + pass + return valid_spec_dict + + def get_all_all_to_all_spec(self, source_spec: ShardingSpec, + orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: + ''' + Get all valid sharding specs from source_spec with single all-to-all operation, and + accumulate commucation cost on origin cost which will finally be used in auto sharding solver. + For the all-to-all operation, we just care about the pairs containing S dimension. + + Argument: + source_spec(ShardingSpec): the ShardingSpec of the source_spec. + orig_cost(Dict[str, float]): the original communication cost before this operation. + + Return: + valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. + + Example: + dim_partition_dict = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4) + sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) + shape_consistency_manager = ShapeConsistencyManager() + rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0}) + print(rst_dict) + + Output: + {DistSpec: + shard_sequence: S01,R,R + device_mesh_shape: (4, 4): 0, DistSpec: + shard_sequence: R,S1,S0 + device_mesh_shape: (4, 4): 0, DistSpec: + shard_sequence: S0,R,S1 + device_mesh_shape: (4, 4): 0} + ''' + valid_spec_dict = {} + comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD + tensor_dims = len(source_spec.entire_shape) + for f_index in range(tensor_dims - 1): + for b_index in range(f_index + 1, tensor_dims): + # skip (R, R) cases + if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict: + continue + else: + if f_index in source_spec.dim_partition_dict: + # skip (S01, R) -> (R, S01) is NOT allowed + if len(source_spec.dim_partition_dict[f_index]) >= 2: + continue + f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index])) + else: + f_target_pair = (f_index, []) + if b_index in source_spec.dim_partition_dict: + # skip (R, S01) -> (S01, R) is NOT allowed + if len(source_spec.dim_partition_dict[b_index]) >= 2: + continue + b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index])) + else: + b_target_pair = (b_index, []) + + # skip (S1, S0) -> S10 + if f_target_pair[1] and b_target_pair[1] and f_target_pair[1][0] >= b_target_pair[1][0]: + continue + f_shard_list, b_shard_list = all_to_all_simulator(f_target_pair, b_target_pair) + f_index = f_target_pair[0] + b_index = b_target_pair[0] + + # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec + if len(f_shard_list) < len(f_target_pair[1]): + gather_dim = f_index + shard_dim = b_index + logical_process_axis = f_target_pair[1][-1] + else: + gather_dim = b_index + shard_dim = f_index + logical_process_axis = b_target_pair[1][-1] + comm_spec = CommSpec(comm_pattern, + sharding_spec=source_spec, + gather_dim=gather_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + forward_only=self.forward_only) + + # compute the communication cost with CommSpec + cost_dict = comm_spec.get_comm_cost() + new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) + + # We won't add empty list into dim_partition_dict + # The key will be popped if the related shard_list is empty + if f_shard_list: + new_dim_partition_dict[f_index] = f_shard_list + else: + new_dim_partition_dict.pop(f_index) + if b_shard_list: + new_dim_partition_dict[b_index] = b_shard_list + else: + new_dim_partition_dict.pop(b_index) + + # generate new sharding spec + try: + new_sharding_spec = ShardingSpec(source_spec.device_mesh, + source_spec.entire_shape, + dim_partition_dict=new_dim_partition_dict) + for phase, cost in cost_dict.items(): + cost_dict[phase] = cost + orig_cost_dict[phase] + valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) + except ShardingSpecException: + pass + + return valid_spec_dict + + def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): + ''' + Get all valid sharding specs from source_spec with single shard operation, and + accumulate commucation cost on origin cost which will finally be used in auto sharding solver. + For the sharding operation, we just care about legal sharding dimensions. + + Argument: + source_spec(ShardingSpec): the ShardingSpec of the source_spec. + orig_cost(float): the original communication cost before this operation. + + Return: + valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. + + Example: + dim_partition_dict = {0: [0]} + # DistSpec: + # shard_sequence: S0,R,R + # device_mesh_shape: (4, 4) + sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) + shape_consistency_manager = ShapeConsistencyManager() + rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0}) + print(rst_dict) + + Output: + {DistSpec: + shard_sequence: S01,R,R + device_mesh_shape: (4, 4): 0, DistSpec: + shard_sequence: S0,S1,R + device_mesh_shape: (4, 4): 0, DistSpec: + shard_sequence: S0,R,S1 + device_mesh_shape: (4, 4): 0} + ''' + valid_spec_dict = {} + comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD + + # legal sharding dims means the mesh_id is still available to use. + legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))] + for dim, shard_list in source_spec.dim_partition_dict.items(): + for element in shard_list: + legal_sharding_dims.remove(element) + if len(legal_sharding_dims) == 0: + return valid_spec_dict + + tensor_dims = len(source_spec.entire_shape) + + for index in range(tensor_dims): + if index not in source_spec.dim_partition_dict: + shard_list_list = shard_simulator((index, []), legal_sharding_dims) + else: + shard_list_list = shard_simulator((index, source_spec.dim_partition_dict[index]), legal_sharding_dims) + if not shard_list_list: + continue + for shard_list in shard_list_list: + new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) + new_dim_partition_dict[index] = shard_list + + # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec + shard_dim = index + logical_process_axis = shard_list[-1] + comm_spec = CommSpec(comm_pattern, + sharding_spec=source_spec, + gather_dim=shard_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + forward_only=self.forward_only) + + # compute the communication cost with CommSpec + cost_dict = comm_spec.get_comm_cost() + + # generate new sharding spec + try: + new_sharding_spec = ShardingSpec(source_spec.device_mesh, + source_spec.entire_shape, + dim_partition_dict=new_dim_partition_dict) + for phase, cost in cost_dict.items(): + cost_dict[phase] = cost + orig_cost_dict[phase] + valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) + except ShardingSpecException: + pass + return valid_spec_dict + + def get_all_mix_gather_spec(self, source_spec: ShardingSpec, + orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: + ''' + S0S1 -> RR + S1S0 -> RR + S01R -> RR + RS01 -> RR + ''' + valid_spec_dict = {} + comm_pathern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD + tensor_dims = len(source_spec.entire_shape) + for f_index in range(tensor_dims - 1): + for b_index in range(f_index + 1, tensor_dims): + if (f_index not in source_spec.dim_partition_dict) and (b_index not in source_spec.dim_partition_dict): + continue + else: + if f_index in source_spec.dim_partition_dict: + # skip (S10, R) -> (R, R) + if len(f_target_pair[1]) == 2 and f_target_pair[1][0] >= f_target_pair[1][1]: + continue + f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index])) + else: + f_target_pair = (f_index, []) + if b_index in source_spec.dim_partition_dict: + # skip (R, S10) -> (R, R) + if len(b_target_pair[1]) == 2 and b_target_pair[1][0] >= b_target_pair[1][1]: + continue + b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index])) + else: + b_target_pair = (b_index, []) + + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + comm_spec = CommSpec(comm_pathern, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=self.forward_only, + mix_gather=True) + cost_dict = comm_spec.get_comm_cost() + new_dim_partition_dict = {} + # generate new sharding spec + try: + new_sharding_spec = ShardingSpec(source_spec.device_mesh, + source_spec.entire_shape, + dim_partition_dict=new_dim_partition_dict) + for phase, cost in cost_dict.items(): + cost_dict[phase] = cost + orig_cost_dict[phase] + valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) + except ShardingSpecException: + pass + + return valid_spec_dict + + def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]: + ''' + Get all valid sharding specs from source_spec with one step transform, and + accumulate commucation cost on origin cost which will finally be used in auto sharding solver. + Note: + all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before, + and shard will add a sharding dimension. Therefore, the result of above operations are mutual exclusive, + we could safely put them together. + + Argument: + source_spec(ShardingSpec): the ShardingSpec of the source_spec. + orig_cost(float): the original communication cost before this operation. + + Return: + valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. + ''' + valid_spec_dict = {} + valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost_dict)) + valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost_dict)) + valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict)) + return valid_spec_dict + + def shape_consistency(self, source_spec: ShardingSpec, + target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]: + ''' + This method will find a path to transform source_spec to target_spec with + a greedy algorithm. + The basic idea is: + Step1: + Generate all one-step transform sequences from source_spec. + Step2: + Pick the 'best' sharding spec following the heuristic function. + Step3: + Repeat above steps until the source spec transform to target spec. + + During finding the transform path, commucation cost will be accumulated, and it + will be finally used in auto parallel solver. + + Additionally, to avoid repeating the path search in runtime, we cached all solved path + in auto parallel strategy building time, which could handle most of cases in runtime. + + Argument: + source_spec(ShardingSpec): ShardingSpec of the source activation. + target_spec(ShardingSpec): ShardingSpec of the target activation. + + Return: + transform_path(List[ShardingSpec]): The transform path from source_spec to target_spec, + it contains the source_spec and target_spec. + comm_action_sequence(List[CommSpec]): Keep the communication operations to complete the shape consistency in order. + total_cost(float): total cost to complete shape consistency transform. + + Example: + dim_partition_source = {1: [0, 1]} + dim_partition_target = {0: [0, 1]} + # DistSpec: + # shard_sequence: R,S01,R + # device_mesh_shape: (4, 4) + sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source) + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) + transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(sharding_spec_source, sharding_spec_target) + print(f'transform_path: {transform_path}') + print(f'comm_action_sequence: {comm_action_sequence}') + print(f'total_cost: {total_cost}') + + output: + transform_path: [DistSpec: + shard_sequence: R,S01,R + device_mesh_shape: (4, 4), DistSpec: + shard_sequence: R,S0,R + device_mesh_shape: (4, 4), DistSpec: + shard_sequence: S0,R,R + device_mesh_shape: (4, 4), DistSpec: + shard_sequence: S01,R,R + device_mesh_shape: (4, 4)] + comm_action_sequence: [CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), + CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 0), + CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1)] + total_cost: 12294.402000000002 + ''' + MAX_TRANSFORM_STEPS = 20 + total_cost_dict = {'forward': 0, 'backward': 0, 'total': 0} + total_steps = 0 + transform_path = [] + comm_action_sequence = [] + spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence)) + self.cached_spec_pairs_transform_path[spec_pairs] = (None, None) + + # We do nothing if the sharding spec is all the same. + if source_spec.sharding_sequence_difference(target_spec) == 0: + self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence) + return (transform_path, comm_action_sequence, total_cost_dict) + + temp_sharding_spec = source_spec + + transform_path.append(temp_sharding_spec) + # To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms + while total_steps <= MAX_TRANSFORM_STEPS: + valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_spec, total_cost_dict) + best_difference_score = math.inf + + for sharding_spec, info_pairs in valid_transform_spec_dict.items(): + comm_spec, cost_dict = info_pairs + spec_difference = sharding_spec.sharding_sequence_difference(target_spec) + + if spec_difference == 0: + for phase, cost in total_cost_dict.items(): + total_cost_dict[phase] = cost + cost_dict[phase] + transform_path.append(sharding_spec) + comm_action_sequence.append(comm_spec) + self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence) + return (transform_path, comm_action_sequence, total_cost_dict) + + if spec_difference < best_difference_score: + temp_sharding_spec = sharding_spec + temp_cost_dict = cost_dict + temp_comm_spec = comm_spec + best_difference_score = spec_difference + + transform_path.append(temp_sharding_spec) + comm_action_sequence.append(temp_comm_spec) + for phase, cost in total_cost_dict.items(): + total_cost_dict[phase] = cost + temp_cost_dict[phase] + total_steps += 1 + + raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") + + def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor: + ''' + Apply target_spec to tensor with source sharding spec, the transform path is generated by the + shape_consistency method. + + Argument: + tensor_with_sharding_spec (torch.Tensor): a tensor with source sharding spec to be transformed to the target spec. + target_spec (ShardingSpec): The tensor transform processes will be directed by the target_spec. + + Example: + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + entire_shape = torch.Size((4, 2)) + shape_consistency_manager = ShapeConsistencyManager() + dim_partition_source = {0: [0]} + dim_partition_target = {1: [0]} + + # DistSpec: + # shard_sequence: S0,R + # device_mesh_shape: (2, 2) + sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source) + + # DistSpec: + # shard_sequence: R,S0 + # device_mesh_shape: (2, 2) + sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) + + if rank in (0, 1): + sharded_tensor_0 = torch.zeros(2, 1) + sharded_tensor_1 = torch.ones(2, 1) + # tensor([[0., 1.], + # [0., 1.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + if rank in (2, 3): + sharded_tensor_0 = torch.ones(2, 1) * 2 + sharded_tensor_1 = torch.ones(2, 1) * 3 + # tensor([[2., 3.], + # [2., 3.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + + tensor_to_comm.sharding_spec = sharding_spec_source + shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target) + print(tensor_to_comm) + + Output in rank0 and rank2: + tensor([[0.], + [0.], + [2.], + [2.]]) + + Output in rank1 and rank3: + tensor([[1.], + [1.], + [3.], + [3.]]) + ''' + _, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec) + for comm_spec in comm_action_sequence: + tensor_with_sharding_spec = comm_spec.covert_spec_to_action(tensor_with_sharding_spec) + tensor_with_sharding_spec.sharding_spec = target_spec + return tensor_with_sharding_spec + + def apply_for_autoparallel_runtime(self, tensor, source_spec, target_spec): + _, comm_action_sequence, _ = self.shape_consistency(source_spec, target_spec) + for comm_spec in comm_action_sequence: + tensor = comm_spec.covert_spec_to_action(tensor) + tensor.sharding_spec = target_spec + return tensor diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..cdd0338850cf5bcaefbbb4dabdc7b79022564069 --- /dev/null +++ b/colossalai/tensor/sharding_spec.py @@ -0,0 +1,296 @@ +import operator +from copy import deepcopy +from functools import reduce + +import torch + +from colossalai.device.device_mesh import DeviceMesh + +from .utils import merge_same_dim_mesh_list + +__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec'] + +ALLGATHER_COST = 20 +SHARD_COST = 5 +STEP_PENALTY = 6 +NAN = 'nan' + + +class _DimSpec: + ''' + Sharding spec for single dimension of the sharded tensor decribe the sharding dimension of + logical device mesh and give a method to compute the difference between them. + This class is used internally in ShardingSpec. + + Argument: + shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type. + Otherwise, the element in shard_list means the data will be sharded in that dimension. + ''' + + def __init__(self, shard_list): + self.is_replica = len(shard_list) == 0 + self.shard_list = shard_list + self.build_difference_2d_dict() + + def __eq__(self, other): + return str(self) == str(other) + + def __repr__(self): + if self.is_replica: + return 'R' + target = 'S' + for dim in self.shard_list: + target += str(dim) + return target + + def _convert_str_to_shard_list(self, str_spec): + ''' + Conver str_spec into shard_list. + + Argument: + str_spec(str): dim spec in str type. + ''' + + if str_spec == 'R': + return [] + if str_spec == 'S0': + return [0] + if str_spec == 'S1': + return [1] + if str_spec == 'S01': + return [0, 1] + + def build_difference_2d_dict(self): + ''' + Build a difference maping for 2D device mesh case. It will be used to + compute the difference between DimSpec pairs. + ''' + + source_spec_list = ['R', 'S0', 'S1', 'S01'] + target_spec_list = ['R', 'S0', 'S1', 'S01'] + difference_dict = {} + for source_spec in source_spec_list: + for target_spec in target_spec_list: + legal_sharding_dims = [] + spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) + source_shard_list = self._convert_str_to_shard_list(source_spec) + target_shard_list = self._convert_str_to_shard_list(target_spec) + + # source same as target + if source_shard_list == target_shard_list: + difference = 0 + + # all_gather(source) -> target + elif len(source_shard_list + ) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list: + difference = ALLGATHER_COST + + # shard(source) -> target + elif len(source_shard_list) == len( + target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[ + -1] not in source_shard_list: + difference = SHARD_COST + + # S1 -> S0 or S0 -> S1 + elif len(source_shard_list) == len(target_shard_list): + # source -> R -> target + difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + + # R -> S01 + elif len(source_shard_list) == len(target_shard_list) - 2: + difference = SHARD_COST + STEP_PENALTY + SHARD_COST + + # S01 -> R + elif len(source_shard_list) == len(target_shard_list) + 2: + difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + + # S1 -> S01 + elif len(source_shard_list) == len(target_shard_list) - 1: + difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST + + # S01 -> S1 + elif len(source_shard_list) == len(target_shard_list) + 1: + difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST + + else: + difference = NAN + difference_dict[spec_pair] = difference + + self.difference_dict = difference_dict + + def difference(self, other): + ''' + The difference between two _DimSpec. + + Argument: + other(_DimSpec): the dim spec to compare with. + + Return: + difference(int): the difference between two _DimSpec. + + Example: + dim_spec = _DimSpec([0]) + other_dim_spec = _DimSpec([0, 1]) + print(dim_spec.difference(other_dim_spec)) + + Output: + 5 + ''' + difference = self.difference_dict[(str(self), str(other))] + return difference + + +class ShardingSpecException(Exception): + pass + + +class ShardingOutOfIndexError(ShardingSpecException): + pass + + +class DuplicatedShardingDimensionError(ShardingSpecException): + pass + + +class ShardingNotDivisibleError(ShardingSpecException): + pass + + +class ShardingSpec: + ''' + Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong + to, the entire shape of the tensor before sharded, and the sharding sequence looks like + [R, R, S0, S1]. + + Argument: + device_mesh(DeviceMesh): A logical view of a physical mesh. + entire_shape(torch.Size): The entire shape of tensor before sharded. + dim_partition_dict(Dict[int, List[int]]๏ผŒ optional): The key is the dimension of tensor to be sharded, + and the value of the key decribe which logical axis will be sharded in that dimension. + sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. + ''' + + def __init__(self, + device_mesh: DeviceMesh, + entire_shape: torch.Size, + dim_partition_dict=None, + sharding_sequence=None): + self.device_mesh = device_mesh + + if isinstance(entire_shape, (list, tuple)): + entire_shape = torch.Size(entire_shape) + self.entire_shape = entire_shape + self.dim_partition_dict = dim_partition_dict + self.sharding_sequence = sharding_sequence + if self.sharding_sequence is None: + assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.' + self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape), + dim_partition_dict=self.dim_partition_dict) + self.convert_dict_to_shard_sequence() + elif self.dim_partition_dict is None: + assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.' + self.convert_shard_sequence_to_dict() + self._sanity_check() + + def __repr__(self): + res_list = ["DistSpec:"] + res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence)) + res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.mesh_shape}") + return ' '.join(res_list) + + def _sanity_check(self): + # make sure all axes in logical device mesh only be used once + dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim())) + for dim, shard_list in self.dim_partition_dict.items(): + for element in shard_list: + if element in dim_check_list: + dim_check_list.remove(element) + else: + raise DuplicatedShardingDimensionError( + f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") + + # make sure that the dimension is not out of index + for dim in self.dim_partition_dict.keys(): + if dim >= len(self.entire_shape): + raise ShardingOutOfIndexError( + f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions" + ) + + # make sure that the sharding for a dimension is divisible by the number of devices + for dim, shard_list in self.dim_partition_dict.items(): + tensor_dim_size = self.entire_shape[dim] + num_devices = 1 + + for element in shard_list: + num_devices *= self.device_mesh.mesh_shape[element] + + if tensor_dim_size % num_devices != 0: + raise ShardingNotDivisibleError( + f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.' + ) + + def convert_dict_to_shard_sequence(self): + ''' + Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence. + ''' + sharding_sequence = [_DimSpec([])] * len(self.entire_shape) + for dim, shard_list in self.dim_partition_dict.items(): + sharding_sequence[dim] = _DimSpec(shard_list) + self.sharding_sequence = sharding_sequence + + def convert_shard_sequence_to_dict(self): + ''' + Convert sharding_sequence into dim_partition_dict. + ''' + new_dim_partition_dict = {} + for index, dim_spec in enumerate(self.sharding_sequence): + if not dim_spec.is_replica: + if index not in new_dim_partition_dict: + new_dim_partition_dict[index] = [] + new_dim_partition_dict[index].extend(dim_spec.shard_list) + self.dim_partition_dict = new_dim_partition_dict + + def sharding_sequence_difference(self, other): + ''' + This function is a naive version of difference computation. It just simply accumulates difference every dimension between the + pair of sharding sequence. + + Example: + dim_partition_dict = {0: [0, 1]} + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) + dim_partition_dict_to_compare = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4) + sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare) + print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare)) + + Output: + 25 + + Argument: + other(ShardingSpec): The ShardingSpec to compared with. + + Return: + difference(int): Difference between two ShardingSpec. + ''' + assert len(self.sharding_sequence) == len( + other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.' + difference = 0 + for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence): + difference += orig_dim_spec.difference(other_dim_spec) + return difference + + def get_sharded_shape_per_device(self): + + sharded_shape = list(self.entire_shape) + for dim, shard_list in self.dim_partition_dict.items(): + mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list] + shard_partitions = reduce(operator.mul, mesh_list, 1) + assert sharded_shape[ + dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' + sharded_shape[dim] //= shard_partitions + return torch.Size(sharded_shape) diff --git a/colossalai/tensor/tensor_spec.py b/colossalai/tensor/tensor_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..580df9f8f31023f0623cadc08fe67d49b309819f --- /dev/null +++ b/colossalai/tensor/tensor_spec.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from typing import Optional + +from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec +from colossalai.tensor.process_group import ProcessGroup + +from .compute_spec import ComputeSpec + + +@dataclass +class ColoTensorSpec: + """ ColoTensorSpec + + A data class for specifications of the `ColoTensor`. + It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`. + The latter two attributes are optional. If not set, they are default value is `Replicate()` and `None`. + """ + pg: ProcessGroup + dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE) + compute_attr: Optional[ComputeSpec] = None diff --git a/colossalai/tensor/utils.py b/colossalai/tensor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0c2ead630d59705032b42702a298180267879da3 --- /dev/null +++ b/colossalai/tensor/utils.py @@ -0,0 +1,226 @@ +from typing import Dict, Iterator, List, Tuple, Union + +import torch +import torch.nn as nn + +from colossalai.tensor.colo_tensor import ColoTensor + + +def all_gather_simulator(target_pair): + ''' + Simulating all-gather operation, analyze the communication cost + and simulate the influence of the DimSpec. + + We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed. + Therefore, all gather operation just remove the last element in shard list, + e.g.: + all-gather(S01) -> S0 + + Argument: + target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, + and the second element decribes which logical axis will be sharded in that dimension. + ''' + _, shard_list = target_pair + new_shard_list = shard_list[:-1] + + return new_shard_list + + +def all_to_all_simulator(f_target_pair, b_target_pair): + ''' + Simulating all-to-all operation, analyze the communication cost + and simulate the influence of the DimSpec. + + We BANNED all representations which shard_list in decreasing order, + such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed. + Therefore, if the behind shard_list is not None, we just extend it to the front shard_list. + Argument: + target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, + and the second element decribes which logical axis will be sharded in that dimension. + e.g.: + all-to-all(S0, S1) -> [S01, R] + all-to-all(S0, R) -> [R, S0] + Otherwise, we extend the front shard_list to behind. + e.g.: + all-to-all(R, S1) -> [S1, R] + + Argument: + target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, + and the second element decribes which logical axis will be sharded in that dimension. + ''' + _, f_shard_list = f_target_pair + _, b_shard_list = b_target_pair + if not len(b_shard_list): + b_shard_list.extend(f_shard_list) + f_shard_list = [] + else: + f_shard_list.extend(b_shard_list) + b_shard_list = [] + + return f_shard_list, b_shard_list + + +def shard_simulator(target_pair, legal_sharding_dims): + ''' + Simulating shard operation, analyze the communication cost(always ZERO) + and simulate the influence of the DimSpec. + + We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed. + In addition, We BANNED all representations which shard_list in decreasing order, + such as S10, so shard(S0) -> S10 is NOT allowed. + Therefore, for the R dimension, we could just append any legal sharding dim on it. + e.g.: + shard(R) -> S0 + For the S dimension, we need to make sure the shard_list after sharding still keep rising order. + e.g: + shard(S0) -> S01 + + Argument: + target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, + and the second element decribes which logical axis will be sharded in that dimension. + ''' + _, shard_list = target_pair + shard_list_list = [] + for dim in legal_sharding_dims: + if len(shard_list) != 0 and dim <= shard_list[-1]: + continue + new_shard_list = shard_list + [dim] + shard_list_list.append(new_shard_list) + + return shard_list_list + + +def mix_gather_simulator(f_target_pair, b_target_pair): + ''' + Assume index of f and b target pairs are 'f' and 'b' + S0S1 => Input: (f, [0]), (b, [1]) Output: [b, f], (1, 0) + S1S0 => Input: (f, [1]), (b, [0]) Output: [b, f], (0, 1) + S01R => Input: (f, [0, 1]), (b, []) Output: [f], (1, 1) + RS01 => Input: (f, []), (b, [0, 1]) Output: [b], (1, 1) + S10R => Input: (f, [0, 1]), (b, []) Output: [f], (0, 0) + RS10 => Input: (f, []), (b, [0, 1]) Output: [b], (0, 0) + ''' + if f_target_pair[1] and b_target_pair[1]: + leading_dim = b_target_pair[1] > f_target_pair[1] + return [b_target_pair[0], f_target_pair[0]], [int(leading_dim), int(leading_dim ^ 1)] + if f_target_pair[1]: + leading_dim = f_target_pair[1][0] < f_target_pair[1][1] + return [ + f_target_pair[0], + ], [int(leading_dim), int(leading_dim)] + if b_target_pair[1]: + leading_dim = b_target_pair[1][0] < b_target_pair[1][1] + return [ + b_target_pair[0], + ], [int(leading_dim), int(leading_dim)] + + +# The function is credited to PyTorch Team +def named_params_with_colotensor( + module: nn.Module, + prefix: str = '', + recurse: bool = True, +) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]: + r"""Returns an iterator over module parameters (together with the + ColoTensor parameters), yielding both the name of the parameter + as well as the parameter itself. This is typically passed to a + :class:torchshard._shard.sharded_optim.ShardedOptimizer + + Args: + prefix (str): prefix to prepend to all parameter names. + recurse (bool): if True, then yields parameters of this module + and all submodules. Otherwise, yields only parameters that + are direct members of this module. + + Yields: + (string, Union[Tensor, ColoTensor]): Tuple containing + the name and parameter (or ColoTensor parameter) + + Example: + + >>> model = torch.nn.Linear(*linear_size) + >>> delattr(model.weight) + >>> setattr(model.weight, ColoTensor(...)) + >>> for name, param in named_params_with_colotensor(model): + >>> if name in ['weight']: + >>> print(param.size()) + + """ + modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)] + + memo = set() + for mod_prefix, mod in modules: + # find all sharded tensor params + for name, val in vars(mod).items(): + if isinstance(val, ColoTensor) and val not in memo: + memo.add(val) + name = mod_prefix + ('.' if mod_prefix else '') + name + yield name, val + + # find all nn.Parameters + for name, val in module.named_parameters(): + yield name, val + + +def _convert_tensor(tensor: torch.Tensor) -> ColoTensor: + return ColoTensor(tensor) + + +def convert_parameter(module: torch.nn.Module, param_name: str): + # Perform some validation first. + if not hasattr(module, param_name): + raise ValueError(f'module: {module} does not have parameter with name: {param_name}') + + tensor = getattr(module, param_name) + if not isinstance(tensor, torch.Tensor): + raise ValueError( + f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}') + + if not tensor.is_contiguous(): + raise ValueError(f'param: {param_name} is not a contiguous Tensor') + + st = _convert_tensor(tensor) + + # Replace param with ColoTensor. + + # Need to delete the attribute first since param_name might be + # torch.nn.Parameter and can't be replaced with ColoTensor which is + # not torch.nn.Parameter. + delattr(module, param_name) + + # Now we can set the attribute appropriately. + setattr(module, param_name, st) + + +def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]: + ''' + This method is used to convert the negative dim value to positive. + ''' + dims_to_convert = [] + for dim, mesh_list in dim_partition_dict.items(): + if dim < 0: + dims_to_convert.append(dim) + for dim in dims_to_convert: + dim_partition_dict.pop(dim) + dim_partition_dict[dim_size + dim] = mesh_list + return dim_partition_dict + + +def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]: + ''' + This method is used to merge the different key value which points to same physical position. + + For example: + dim_partition_dict: {1 :[0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor, the dim 1 and -1 point same physical position. + In this method, above dim_partition_dict will be converted to {1: [0, 1]} + ''' + converted_dim_partition_dict = {} + for dim, mesh_list in dim_partition_dict.items(): + if dim < 0: + dim = dim_size + dim + if dim not in converted_dim_partition_dict: + converted_dim_partition_dict[dim] = mesh_list + else: + converted_dim_partition_dict[dim].extend(mesh_list) + + return converted_dim_partition_dict diff --git a/colossalai/testing/__init__.py b/colossalai/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3dd500dea8efc9393bd808fa71bf3f5bd171504 --- /dev/null +++ b/colossalai/testing/__init__.py @@ -0,0 +1,7 @@ +from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group +from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use, skip_if_not_enough_gpus + +__all__ = [ + 'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize', + 'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus' +] diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py new file mode 100644 index 0000000000000000000000000000000000000000..de4f460c0bebefd4bf708c1047ff4022f844105a --- /dev/null +++ b/colossalai/testing/comparison.py @@ -0,0 +1,33 @@ +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup + + +def assert_equal(a: Tensor, b: Tensor): + assert torch.all(a == b), f'expected a and b to be equal but they are not, {a} vs {b}' + + +def assert_not_equal(a: Tensor, b: Tensor): + assert not torch.all(a == b), f'expected a and b to be not equal but they are, {a} vs {b}' + + +def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8): + assert torch.allclose(a, b, rtol=rtol, atol=atol), f'expected a and b to be close but they are not, {a} vs {b}' + + +def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3): + assert_close(a, b, rtol, atol) + + +def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): + # all gather tensors from different ranks + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_list, tensor, group=process_group) + + # check if they are equal one by one + for i in range(world_size - 1): + a = tensor_list[i] + b = tensor_list[i + 1] + assert torch.all(a == b), f'expected tensors on rank {i} and {i+1} to be equal but they are not, {a} vs {b}' diff --git a/colossalai/testing/pytest_wrapper.py b/colossalai/testing/pytest_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..a472eb3723ec1e8ce8705aa978355a460179a328 --- /dev/null +++ b/colossalai/testing/pytest_wrapper.py @@ -0,0 +1,40 @@ +""" +This file will not be automatically imported by `colossalai.testing` +as this file has a dependency on `pytest`. Therefore, you need to +explicitly import this file `from colossalai.testing.pytest_wrapper import `.from +""" + +import pytest +import os + + +def run_on_environment_flag(name: str): + """ + Conditionally run a test based on the environment variable. If this environment variable is set + to 1, this test will be executed. Otherwise, this test is skipped. The environment variable is default to 0. + + Args: + name (str): the name of the environment variable flag. + + Usage: + # in your pytest file + @run_on_environment_flag(name='SOME_FLAG') + def test_for_something(): + do_something() + + # in your terminal + # this will execute your test + SOME_FLAG=1 pytest test_for_something.py + + # this will skip your test + pytest test_for_something.py + + """ + assert isinstance(name, str) + flag = os.environ.get(name.upper(), '0') + + reason = f'Environment varialbe {name} is {flag}' + if flag == '1': + return pytest.mark.skipif(False, reason=reason) + else: + return pytest.mark.skipif(True, reason=reason) diff --git a/colossalai/testing/random.py b/colossalai/testing/random.py new file mode 100644 index 0000000000000000000000000000000000000000..ad6d24a4b94b152f20f404f5b1bb1cae2d74a0b5 --- /dev/null +++ b/colossalai/testing/random.py @@ -0,0 +1,19 @@ +import random + +import numpy as np +import torch + + +def seed_all(seed, cuda_deterministic=False): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if cuda_deterministic: # slower, more reproducible + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + else: + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..64c1d6e7bcd0486f4fdd32a9b91a03af4d5addcc --- /dev/null +++ b/colossalai/testing/utils.py @@ -0,0 +1,204 @@ +import re +import torch +from typing import Callable, List, Any +from functools import partial +from inspect import signature +from packaging import version + + +def parameterize(argument: str, values: List[Any]) -> Callable: + """ + This function is to simulate the same behavior as pytest.mark.parameterize. As + we want to avoid the number of distributed network initialization, we need to have + this extra decorator on the function launched by torch.multiprocessing. + + If a function is wrapped with this wrapper, non-paramterized arguments must be keyword arguments, + positioanl arguments are not allowed. + + Usgae:: + + # Example 1: + @parameterize('person', ['xavier', 'davis']) + def say_something(person, msg): + print(f'{person}: {msg}') + + say_something(msg='hello') + + # This will generate output: + # > xavier: hello + # > davis: hello + + # Exampel 2: + @parameterize('person', ['xavier', 'davis']) + @parameterize('msg', ['hello', 'bye', 'stop']) + def say_something(person, msg): + print(f'{person}: {msg}') + + say_something() + + # This will generate output: + # > xavier: hello + # > xavier: bye + # > xavier: stop + # > davis: hello + # > davis: bye + # > davis: stop + + Args: + argument (str): the name of the argument to parameterize + values (List[Any]): a list of values to iterate for this argument + """ + + def _wrapper(func): + + def _execute_function_by_param(**kwargs): + for val in values: + arg_map = {argument: val} + partial_func = partial(func, **arg_map) + partial_func(**kwargs) + + return _execute_function_by_param + + return _wrapper + + +def rerun_on_exception(exception_type: Exception = Exception, pattern: str = None, max_try: int = 5) -> Callable: + """ + A decorator on a function to re-run when an exception occurs. + + Usage:: + + # rerun for all kinds of exception + @rerun_on_exception() + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun for RuntimeError only + @rerun_on_exception(exception_type=RuntimeError) + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun for maximum 10 times if Runtime error occurs + @rerun_on_exception(exception_type=RuntimeError, max_try=10) + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun for infinite times if Runtime error occurs + @rerun_on_exception(exception_type=RuntimeError, max_try=None) + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun only the exception message is matched with pattern + # for infinite times if Runtime error occurs + @rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$") + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + Args: + exception_type (Exception, Optional): The type of exception to detect for rerun + pattern (str, Optional): The pattern to match the exception message. + If the pattern is not None and matches the exception message, + the exception will be detected for rerun + max_try (int, Optional): Maximum reruns for this function. The default value is 5. + If max_try is None, it will rerun foreven if exception keeps occurings + """ + + def _match_lines(lines, pattern): + for line in lines: + if re.match(pattern, line): + return True + return False + + def _wrapper(func): + + def _run_until_success(*args, **kwargs): + try_count = 0 + assert max_try is None or isinstance(max_try, int), \ + f'Expected max_try to be None or int, but got {type(max_try)}' + + while max_try is None or try_count < max_try: + try: + try_count += 1 + ret = func(*args, **kwargs) + return ret + except exception_type as e: + error_lines = str(e).split('\n') + if try_count < max_try and (pattern is None or _match_lines(error_lines, pattern)): + print('Exception is caught, retrying...') + # when pattern is not specified, we always skip the exception + # when pattern is specified, we only skip when pattern is matched + continue + else: + print('Maximum number of attempts is reached or pattern is not matched, no more retrying...') + raise e + + # Override signature + # otherwise pytest.mark.parameterize will raise the following error: + # function does not use argumetn xxx + sig = signature(func) + _run_until_success.__signature__ = sig + + return _run_until_success + + return _wrapper + + +def rerun_if_address_is_in_use(): + """ + This function reruns a wrapped function if "address already in use" occurs + in testing spawned with torch.multiprocessing + + Usage:: + + @rerun_if_address_is_in_use() + def test_something(): + ... + + """ + # check version + torch_version = version.parse(torch.__version__) + assert torch_version.major == 1 + + # only torch >= 1.8 has ProcessRaisedException + if torch_version.minor >= 8: + exception = torch.multiprocessing.ProcessRaisedException + else: + exception = Exception + + func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*") + return func_wrapper + + +def skip_if_not_enough_gpus(min_gpus: int): + """ + This function is used to check the number of available GPUs on the system and + automatically skip the test cases which require more GPUs. + + Note: + The wrapped function must have `world_size` in its keyword argument. + + Usage: + @skip_if_not_enough_gpus(min_gpus=8) + def test_something(): + # will be skipped if there are fewer than 8 GPUs available + do_something() + + Arg: + min_gpus (int): the minimum number of GPUs required to run this test. + """ + + def _wrap_func(f): + + def _execute_by_gpu_num(*args, **kwargs): + num_avail_gpu = torch.cuda.device_count() + if num_avail_gpu >= min_gpus: + f(*args, **kwargs) + + return _execute_by_gpu_num + + return _wrap_func diff --git a/colossalai/trainer/__init__.py b/colossalai/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..84e53dc4e87ac5b10a93aacc0fce975cc49c66eb --- /dev/null +++ b/colossalai/trainer/__init__.py @@ -0,0 +1,3 @@ +from ._trainer import Trainer + +__all__ = ['Trainer'] diff --git a/colossalai/trainer/_trainer.py b/colossalai/trainer/_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..60bbc4eeee32adbd46472106fff4a07c653f7cb6 --- /dev/null +++ b/colossalai/trainer/_trainer.py @@ -0,0 +1,409 @@ +from typing import Union, List, Any + +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from colossalai.engine import Engine +from colossalai.logging import DistributedLogger +from colossalai.utils import MultiTimer +from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage +from colossalai.trainer.hooks import BaseHook + + +class Trainer: + r"""This is a class tending for easy deployments of users' training and evaluation instead of + writing their own scripts. It is similar with ``ignite.engine`` and ``keras.engine``, but is + called `Trainer`. + + Args: + engine (:class:`Engine`): Engine responsible for the process function. + timer (:class:`MultiTimer`, optional): Timer used to monitor the whole training. + logger (:class:`colossalai.logging.DistributedLogger`, optional): Logger used to record the whole training log. + + + Examples: + >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training + >>> model = ... + >>> criterion = ... + >>> optimizer = ... + >>> train_dataloader = ... + >>> # Initialize your engine, train_dataloader, test_dataloader, lr_scheduler + >>> engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion) + >>> # Beginning training progress + >>> timier = ... + >>> logger = ... + >>> trainer = Trainer(engine=engine, logger=logger, timer=timier) + >>> # add hooks you would like to use here. + >>> hook_list = [] + >>> trainer.fit( + >>> train_dataloader=train_dataloader, + >>> epochs=gpc.config.NUM_EPOCHS, + >>> test_interval=1, + >>> hooks=hook_list, + >>> display_progress=True, + >>> return_output_label=False + >>> ) + + More examples and details could be found in + `Training with engine and trainer `_ + and `ColossalAI-Examples `_. + """ + + def __init__( + self, + engine: Engine, + timer: MultiTimer = None, + logger: DistributedLogger = None, + ): + # training-ralated params + self._engine = engine + self._max_epochs = 0 + self._cur_epoch = 0 + self._max_steps = 0 + self._cur_step = 0 + self._steps_per_epoch = 0 + + # misc params + self._logger = logger + self._verbose = logger is not None + + # hooks can store states in this dict, and could be consumed by other hooks + self.states = dict() + + # build hooks + self.hooks = list() + + # multi-timer for time benchmarking + self._timer = timer + + @property + def cur_epoch(self): + """Returns the index of the current epoch.""" + return self._cur_epoch + + @cur_epoch.setter + def cur_epoch(self, epoch: int): + """Set how many epochs have been processed.""" + # allow setter for training resumption + self._cur_epoch = epoch + + @property + def cur_step(self): + """Returns how many iteration steps have been processed.""" + return self._cur_step + + @property + def max_epochs(self): + return self._max_epochs + + @property + def max_steps(self): + return self._max_steps + + @property + def steps_per_epoch(self): + return self._steps_per_epoch + + @property + def engine(self): + return self._engine + + def _set_current_step(self, epoch: int): + """Sets current step number. + + Args: + epoch (int): Step number to be set. + """ + self._cur_step = epoch * self._steps_per_epoch + + def _call_timer(self, action: str, item: str, *args, **kwargs) -> None: + """Call timer funciton with a given timer name. + + Args: + action (str): Function to be called on timer. + item (str): Name of the timer. + args (list): args used for action function. + kwargs (dict): kwargs used for action function. + """ + + if self._timer is not None: + getattr(self._timer, action)(item, *args, **kwargs) + + def _reset_states(self) -> None: + """Clear trainer states""" + self.states = dict() + + def _call_hooks(self, func, output=None): + """Calls specific hooks in the current time point. + + Args: + func (str): A string represents the time point. + output (Any, optional): Output of the model after running an iteration or None in any other time points. + """ + # Only after iter hook will receive output + for hook in self.hooks: + if output is None: + getattr(hook, func)(self) + else: + getattr(hook, func)(self, *output) + + @staticmethod + def _should_display_progress(display_progress: bool): + """Only display progress on DP rank 0, TP rank 0 and PP last rank""" + return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()) + + def _train_epoch( + self, + train_dataloader: DataLoader, + epoch: int = None, + display_progress: bool = False, + return_output_label: bool = True, + ): + # set training state + self._engine.train() + data_iter = iter(train_dataloader) + progress = range(self._steps_per_epoch) + if display_progress: + if epoch is None: + progress = tqdm(progress, desc="[Train]") + else: + progress = tqdm(progress, desc=f"[Epoch {epoch} / Train]") + + self._call_hooks("before_train_epoch") + self._call_timer(action="start", item="Train-epoch") + for i in progress: + self._call_hooks("before_train_iter") + self._call_timer(action="start", item="Train-step") + + # run 1 training step + self.engine.zero_grad() + logits, label, loss = self.engine.execute_schedule( + data_iter, + forward_only=False, + return_loss=True, + return_output_label=return_output_label, + ) + self.engine.step() + self._call_timer(action="stop", item="Train-step", keep_in_history=True) + self._call_hooks("after_train_iter", output=(logits, label, loss)) + + self._cur_step += 1 + + if display_progress: + if "step_metrics" in self.states: + progress.set_postfix(**self.states["step_metrics"]) + + # stop when max iter is reached + if self._exceed_max_step(): + break + + self._call_timer(action="stop", item="Train-epoch", keep_in_history=True) + self._call_hooks("after_train_epoch") + self._call_timer(action="reset", item="Train-epoch") + + def _eval( + self, + test_dataloader: DataLoader, + epoch: int = None, + display_progress: bool = False, + return_output_label: bool = True, + ): + # switch engine status + self._engine.eval() + + data_iter = iter(test_dataloader) + num_steps = len(test_dataloader) + + self._call_hooks("before_test") + # prepare progress bar + progress = range(num_steps) + if display_progress: + desc = "Evaluation" + if epoch is not None: + desc = "[Epoch %d / Test]" % epoch + progress = tqdm(progress, desc=desc) + + self._call_hooks("before_test_epoch") + self._call_timer(action="start", item="Test-epoch") + with torch.no_grad(): + for _ in progress: + self._call_hooks("before_test_iter") + self._call_timer(action="start", item="Test-step") + logits, label, loss = self.engine.execute_schedule( + data_iter, + forward_only=True, + return_loss=True, + return_output_label=return_output_label, + ) + self._call_timer(action="stop", item="Test-step", keep_in_history=True) + self._call_hooks("after_test_iter", output=(logits, label, loss)) + + if display_progress: + if "step_metrics" in self.states: + progress.set_postfix(**self.states["step_metrics"]) + + self._call_timer(action="stop", item="Test-epoch", keep_in_history=True) + self._call_hooks("after_test_epoch") + self._call_hooks("after_test") + self._call_timer(action="reset", item="Test-step") + self._call_timer(action="reset", item="Test-epoch") + + def _exceed_max_step(self): + return self._max_steps is not None and self._cur_step >= self._max_steps + + def fit( + self, + train_dataloader: DataLoader, + epochs: int, + max_steps: int = None, + test_dataloader: DataLoader = None, + test_interval: int = 1, + hooks: List[BaseHook] = None, + display_progress: bool = False, + return_output_label: bool = True, + ): + r"""Trains the model to fit training data. + + Args: + train_dataloader (:class:`torch.utils.data.DataLoader`): DataLoader for training. + epochs (int): Maximum number of epochs. + max_steps (int, optional): Maximum number of running iterations. + test_dataloader (:class:`torch.utils.data.DataLoader`, optional): DataLoader for validation. + test_interval (int, optional): Interval of validation + hooks (list[BaseHook], optional): A list of hooks used in training. + display_progress (bool, optional): If True, a progress bar will be displayed. + """ + + # set epochs and steps, consider gradient accumulation + self._steps_per_epoch = len(train_dataloader) + self._max_steps = max_steps + self._max_epochs = epochs + + # check if testing is required + should_test = False + if test_dataloader is not None: + should_test = True + + display_progress = self._should_display_progress(display_progress) + + # reset hooks + self._reset_states() + if hooks is not None: + assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}" + + for hook in hooks: + assert isinstance(hook, BaseHook), \ + f'expected the hook to be of type BaseHook, but got {type(hook)}' + else: + hooks = [] + self.hooks = hooks + self.hooks.sort(key=lambda hook: hook.priority) + if self._verbose: + for hook in self.hooks: + self._logger.info( + f"Using {hook.__class__.__name__} for training, priority = {hook.priority}", + ranks=[0], + ) + self._logger.info("Lower value means higher priority for calling hook function", ranks=[0]) + self._call_hooks("after_hook_is_attached") + + self._engine.train() + self._call_hooks("before_train") + + # recover step value if resuming training + last_epoch = self._cur_epoch + if self.cur_epoch != 0: + self._set_current_step(last_epoch) + + for epoch in range(last_epoch, epochs): + # train for one epoch + self._train_epoch( + train_dataloader=train_dataloader, + epoch=epoch, + display_progress=display_progress, + return_output_label=return_output_label, + ) + + # start eval + if should_test and epoch % test_interval == 0: + self._eval( + test_dataloader=test_dataloader, + display_progress=display_progress, + epoch=epoch, + return_output_label=return_output_label, + ) + + self._cur_epoch += 1 + + # check for termination + if self._exceed_max_step(): + self._logger.info( + f"Max number of steps {max_steps} has been reached, training is stopped automatically", + ranks=[0], + ) + break + self._call_hooks("after_train") + self._call_timer("reset", "Train-epoch") + + def evaluate( + self, + test_dataloader: DataLoader, + hooks: List[BaseHook] = None, + display_progress: bool = False, + return_output_label: bool = True, + ): + """Evaluates the model with testing data. + + Args: + test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing. + hooks (list, optional): A list of hooks used in evaluation. Defaults to None. + display_progress (bool, optional): If True, the evaluation progress will be printed. Defaults to False. + return_output_label (bool, optional): If True, the output of model and the label + will be returned. Defaults to True. + """ + # set display + display_progress = self._should_display_progress(display_progress) + + # reset hooks + self._reset_states() + if hooks is not None: + assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}" + else: + hooks = [] + self.hooks = hooks + self.hooks.sort(key=lambda hook: hook.priority) + if self._verbose: + for hook in self.hooks: + self._logger.info( + f"Using {hook.__class__.__name__} for training, priority = {hook.priority}", + ranks=[0], + ) + self._logger.info("Lower value means higher priority for calling hook function", ranks=[0]) + self._call_hooks("after_hook_is_attached") + + # eval + self._eval( + test_dataloader=test_dataloader, + display_progress=display_progress, + return_output_label=return_output_label, + ) + + def predict(self, data: Union[Any, List[Any]]): + """Uses trained model to make a prediction for a tensor or a tensor list. + + Args: + data (Union[:class:`torch.tensor`, List[:class:`torch.tensor`]]): Data as the input. + + Returns: + :class:`torch.tensor`: The output of model as the prediction + """ + # predict without labels + self._engine.eval() + + # prepare a list of (data, label) to make it iterable + # for compatibility with schedule + simple_dataloader = [(data, None)] + data_iter = iter(simple_dataloader) + output, _, _ = self.engine.execute_schedule(data_iter, forward_only=True, return_loss=False) + return output diff --git a/colossalai/trainer/hooks/__init__.py b/colossalai/trainer/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4d36093833d99429b15fb35962c930646b5cbf64 --- /dev/null +++ b/colossalai/trainer/hooks/__init__.py @@ -0,0 +1,12 @@ +from ._base_hook import BaseHook +from ._checkpoint_hook import SaveCheckpointHook +from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook, + TensorboardHook) +from ._lr_scheduler_hook import LRSchedulerHook +from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook + +__all__ = [ + 'BaseHook', 'MetricHook', 'LossHook', 'AccuracyHook', 'LogMetricByEpochHook', 'TensorboardHook', + 'LogTimingByEpochHook', 'LogMemoryByEpochHook', 'LRSchedulerHook', 'ThroughputHook', 'LogMetricByStepHook', + 'SaveCheckpointHook' +] diff --git a/colossalai/trainer/hooks/_base_hook.py b/colossalai/trainer/hooks/_base_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..cca8e081ec883b8b5f3d88633ec9f57cc9fd6dfc --- /dev/null +++ b/colossalai/trainer/hooks/_base_hook.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC + +from torch import Tensor + + +class BaseHook(ABC): + """This class allows users to add desired actions in specific time points + during training or evaluation. + + :param priority: Priority in the printing, hooks with small priority will be printed in front + :type priority: int + """ + + def __init__(self, priority: int) -> None: + self.priority = priority + + def after_hook_is_attached(self, trainer): + """Actions after hooks are attached to trainer. + """ + pass + + def before_train(self, trainer): + """Actions before training. + """ + pass + + def after_train(self, trainer): + """Actions after training. + """ + pass + + def before_train_iter(self, trainer): + """Actions before running a training iteration. + """ + pass + + def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): + """Actions after running a training iteration. + + Args: + trainer (:class:`Trainer`): Trainer which is using this hook. + output (:class:`torch.Tensor`): Output of the model. + label (:class:`torch.Tensor`): Labels of the input data. + loss (:class:`torch.Tensor`): Loss between the output and input data. + """ + pass + + def before_train_epoch(self, trainer): + """Actions before starting a training epoch. + """ + pass + + def after_train_epoch(self, trainer): + """Actions after finishing a training epoch. + """ + pass + + def before_test(self, trainer): + """Actions before evaluation. + """ + pass + + def after_test(self, trainer): + """Actions after evaluation. + """ + pass + + def before_test_epoch(self, trainer): + """Actions before starting a testing epoch. + """ + pass + + def after_test_epoch(self, trainer): + """Actions after finishing a testing epoch. + """ + pass + + def before_test_iter(self, trainer): + """Actions before running a testing iteration. + """ + pass + + def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): + """Actions after running a testing iteration. + + Args: + trainer (:class:`Trainer`): Trainer which is using this hook + output (:class:`torch.Tensor`): Output of the model + label (:class:`torch.Tensor`): Labels of the input data + loss (:class:`torch.Tensor`): Loss between the output and input data + """ + pass + + def init_runner_states(self, trainer, key, val): + """Initializes trainer's state. + + Args: + trainer (:class:`Trainer`): Trainer which is using this hook + key: Key of state to be reset + val: Value of state to be reset + """ + if key not in trainer.states: + trainer.states[key] = val diff --git a/colossalai/trainer/hooks/_checkpoint_hook.py b/colossalai/trainer/hooks/_checkpoint_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..3bcb32cd2dcbc46a9e57dfec1f72abe9bd4aabda --- /dev/null +++ b/colossalai/trainer/hooks/_checkpoint_hook.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import torch +from colossalai.logging import get_dist_logger + +from colossalai.registry import HOOKS +from colossalai.trainer.hooks import BaseHook +from colossalai.utils.checkpointing import save_checkpoint +from ._lr_scheduler_hook import LRSchedulerHook + + +@HOOKS.register_module +class SaveCheckpointHook(BaseHook): + """Saves the model by interval in training process. + + Args: + interval (int, optional): Number of epochs between saving the checkpoint, defaults to 1. + if save_by_iter is True, this arg refers to the number of iters between saving. + checkpoint_dir (str, optional): File name to save the checkpoint, defaults to None. + model (torch.nn.Module, Optional): The model to save, defaults to None. When not passing, + 'trainer.engine.model' will be used. We encourage you to pass the model in it to avoid some + unexpected bugs, especially when using **DDP**. + save_by_iter (bool, optional): Whether saving the checkpoint by iter, default to False. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__(self, + interval: int = 1, + checkpoint_dir: str = None, + model: torch.nn.Module = None, + save_by_iter: bool = False, + priority: int = 10): + super().__init__(priority=priority) + self.interval = interval + self.checkpoint_dir = checkpoint_dir + self.model = model + self.save_by_iter = save_by_iter + self.logger = get_dist_logger() + + # get lr scheduler from the LRSchedulerHook before train + self._lr_scheduler = None + + def after_hook_is_attached(self, trainer): + # get lr scheduler if exists + for hook in trainer.hooks: + if isinstance(hook, LRSchedulerHook): + self._lr_scheduler = hook.lr_scheduler + break + self.model = self.model if self.model is not None else trainer.engine.model + + def after_train_iter(self, trainer, output, label, loss): + """Saves the model after a training iter. + """ + # save by interval + if self.save_by_iter and trainer.cur_step % self.interval == 0: + save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, + self._lr_scheduler) + self.logger.info(f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}', + ranks=[0]) + else: + pass + + def after_train_epoch(self, trainer): + """Saves the model after a training epoch. + """ + # save by interval + if trainer.cur_epoch % self.interval == 0: + save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, + self._lr_scheduler) + self.logger.info(f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0]) diff --git a/colossalai/trainer/hooks/_commons_.py b/colossalai/trainer/hooks/_commons_.py new file mode 100644 index 0000000000000000000000000000000000000000..4923b8cba6c04e482bd4f5163b33767d34e83ebb --- /dev/null +++ b/colossalai/trainer/hooks/_commons_.py @@ -0,0 +1,9 @@ +import torch + + +def _format_number(val, prec=5): + if isinstance(val, float): + return f'{val:.{prec}g}' + elif torch.is_tensor(val) and torch.is_floating_point(val): + return f'{val.item():.{prec}g}' + return val diff --git a/colossalai/trainer/hooks/_log_hook.py b/colossalai/trainer/hooks/_log_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..5b1f33983422b11389927190acbef90165a15e2a --- /dev/null +++ b/colossalai/trainer/hooks/_log_hook.py @@ -0,0 +1,301 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import os +import os.path as osp + +from typing import List +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.registry import HOOKS +from colossalai.logging import DistributedLogger +from colossalai.utils import report_memory_usage, is_dp_rank_0, \ + is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer +from ._base_hook import BaseHook +from ._commons_ import _format_number +from colossalai.trainer.hooks._metric_hook import ThroughputMetric + + +class LogByEpochHook(BaseHook): + """Hook to log by epoch. + + Args: + logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. + interval (int, optional): Interval of printing log information, defaults to 1. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, + defaults to 1. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__(self, logger, interval: int = 1, priority: int = 1): + super().__init__(priority) + self.logger = logger + self._interval = interval + + def _is_epoch_to_log(self, trainer): + return trainer.cur_epoch % self._interval == 0 + + +@HOOKS.register_module +class LogMetricByStepHook(BaseHook): + """Hook to log metric by step. + + Args: + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__(self, priority: int = 10): + super().__init__(priority) + + def after_train_iter(self, trainer, *args): + trainer.states['step_metrics'] = dict() + for metric_name, metric_calculator in trainer.states['metrics']['train'].items(): + if isinstance(metric_calculator, ThroughputMetric): + trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info() + else: + trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value() + + def after_test_iter(self, trainer, *args): + trainer.states['step_metrics'] = dict() + for metric_name, metric_calculator in trainer.states['metrics']['test'].items(): + if isinstance(metric_calculator, ThroughputMetric): + trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info() + else: + trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value() + + +@HOOKS.register_module +class LogMetricByEpochHook(LogByEpochHook): + """Specialized hook to record the metric to log. + + Args: + logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. + interval (int, optional): Interval of printing log information, defaults to 1. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__(self, logger, interval: int = 1, priority: int = 10) -> None: + super().__init__(logger, interval, priority) + self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() + + def _get_str(self, trainer, mode): + msg = [] + for metric_name, metric_calculator in trainer.states['metrics'][mode].items(): + msg.append(f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}') + msg = ' | '.join(msg) + return msg + + def after_train_epoch(self, trainer): + if self._is_epoch_to_log(trainer): + msg = self._get_str(trainer=trainer, mode='train') + + if self._is_rank_to_log: + self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg}') + # f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}') + + def after_test_epoch(self, trainer): + if self._is_epoch_to_log(trainer): + msg = self._get_str(trainer=trainer, mode='test') + if self._is_rank_to_log: + self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}') + # f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}') + + +@HOOKS.register_module +class TensorboardHook(BaseHook): + """Specialized hook to record the metric to Tensorboard. + + Args: + log_dir (str): Directory of log. + ranks (list): Ranks of processors. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): Parallel mode used in trainer, + defaults to colossalai.context.parallel_mode.ParallelMode.GLOBAL. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__( + self, + log_dir: str, + ranks: List = None, + parallel_mode: ParallelMode = ParallelMode.GLOBAL, + priority: int = 10, + ) -> None: + super().__init__(priority=priority) + from torch.utils.tensorboard import SummaryWriter + + # create log dir + if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: + os.makedirs(log_dir, exist_ok=True) + + # determine the ranks to generate tensorboard logs + self._is_valid_rank_to_log = False + if not gpc.is_initialized(parallel_mode): + self._is_valid_rank_to_log = True + else: + local_rank = gpc.get_local_rank(parallel_mode) + + if ranks is None or local_rank in ranks: + self._is_valid_rank_to_log = True + + # check for + if gpc.is_initialized(ParallelMode.PIPELINE) and \ + not gpc.is_last_rank(ParallelMode.PIPELINE) and self._is_valid_rank_to_log: + raise ValueError("Tensorboard hook can only log on the last rank of pipeline process group") + + if self._is_valid_rank_to_log: + # create workspace on only one rank + if gpc.is_initialized(parallel_mode): + rank = gpc.get_local_rank(parallel_mode) + else: + rank = 0 + + # create workspace + log_dir = osp.join(log_dir, f'{parallel_mode}_rank_{rank}') + os.makedirs(log_dir, exist_ok=True) + + self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=f'_rank_{rank}') + + def _log_by_iter(self, trainer, mode: str): + for metric_name, metric_calculator in trainer.states['metrics'][mode].items(): + if metric_calculator.epoch_only: + continue + val = metric_calculator.get_last_step_value() + + if self._is_valid_rank_to_log: + self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step) + + def _log_by_epoch(self, trainer, mode: str): + for metric_name, metric_calculator in trainer.states['metrics'][mode].items(): + if metric_calculator.epoch_only: + val = metric_calculator.get_accumulated_value() + if self._is_valid_rank_to_log: + self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step) + + def after_test_iter(self, trainer, *args): + self._log_by_iter(trainer, mode='test') + + def after_test_epoch(self, trainer): + self._log_by_epoch(trainer, mode='test') + + def after_train_iter(self, trainer, *args): + self._log_by_iter(trainer, mode='train') + + def after_train_epoch(self, trainer): + self._log_by_epoch(trainer, mode='train') + + +@HOOKS.register_module +class LogTimingByEpochHook(LogByEpochHook): + """Specialized hook to write timing record to log. + + Args: + timer (:class:`colossalai.utils.MultiTimer`): Timer for the hook. + logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. + interval (int, optional): Interval of printing log information, defaults to 1. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + log_eval (bool, optional): Whether writes in evaluation, defaults to True. + ignore_num_train_steps (int, optional): Number of training steps to ignore, defaults to 0. + """ + + def __init__(self, + timer: MultiTimer, + logger: DistributedLogger, + interval: int = 1, + priority: int = 10, + log_eval: bool = True, + ignore_num_train_steps: int = 0) -> None: + super().__init__(logger=logger, interval=interval, priority=priority) + self._timer = timer + self._log_eval = log_eval + self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() + + # extra handling to avoid the unstable readings of the first + # few training steps to affect the history mean time + self._ignore_num_train_steps = ignore_num_train_steps + self._is_train_step_history_trimmed = False + + def _get_message(self, mode): + msg = [] + for timer_name, timer in self._timer: + if timer_name.startswith(mode): + last_elapsed_time = timer.get_elapsed_time() + if timer.has_history: + if timer_name == 'Train-step' and not self._is_train_step_history_trimmed: + timer._history = timer._history[self._ignore_num_train_steps:] + self._is_train_step_history_trimmed = True + history_mean = timer.get_history_mean() + history_sum = timer.get_history_sum() + msg.append( + f'{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s' + ) + else: + msg.append(f'{timer_name}: last = {_format_number(last_elapsed_time)} s') + + msg = ' | '.join(msg) + return msg + + def after_train_epoch(self, trainer): + """Writes log after finishing a training epoch. + """ + if self._is_epoch_to_log(trainer) and self._is_rank_to_log: + msg = self._get_message('Train') + self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg} | #steps/epoch = {trainer.steps_per_epoch}') + + def after_test_epoch(self, trainer): + """Writes log after finishing a testing epoch. + """ + if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval: + msg = self._get_message('Test') + self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}') + + +@HOOKS.register_module +class LogMemoryByEpochHook(LogByEpochHook): + """Specialized Hook to write memory usage record to log. + + Args: + logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. + interval (int, optional): Interval of printing log information, defaults to 1. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 1. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + log_eval (bool, optional): Whether writes in evaluation, defaults to True. + """ + + def __init__( + self, + logger: DistributedLogger, + interval: int = 1, + priority: int = 10, + log_eval: bool = True, + report_cpu: bool = False, # no reference + ) -> None: + super().__init__(logger=logger, interval=interval, priority=priority) + self._log_eval = log_eval + self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() + + def before_train(self, trainer): + """Resets before training. + """ + if self._is_epoch_to_log(trainer) and self._is_rank_to_log: + report_memory_usage('Before-train', self.logger) + + def after_train_epoch(self, trainer): + """Writes log after finishing a training epoch. + """ + if self._is_epoch_to_log(trainer) and self._is_rank_to_log: + report_memory_usage(f'[Epoch {trainer.cur_epoch} / Train]', self.logger) + + def after_test(self, trainer): + """Reports after testing. + """ + if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval: + report_memory_usage(f'[Epoch {trainer.cur_epoch} / Test]', self.logger) diff --git a/colossalai/trainer/hooks/_lr_scheduler_hook.py b/colossalai/trainer/hooks/_lr_scheduler_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..c6da33442dc39b78474fb36c50b3a9bbfc790666 --- /dev/null +++ b/colossalai/trainer/hooks/_lr_scheduler_hook.py @@ -0,0 +1,47 @@ +from colossalai.registry import HOOKS +from torch import Tensor + +from ._metric_hook import LearningRateMetric, MetricHook + + +@HOOKS.register_module +class LRSchedulerHook(MetricHook): + r"""Build LR scheduler for trainer. + + Args: + lr_scheduler (:class:`colossalai.nn.lr_scheduler`): The specific LR scheduler + in range of ``colossalai.nn.lr_scheduler``, more details about ``lr_scheduler`` could be found in + `lr_scheduler `_. + by_epoch (bool): If `True`, the LR will be scheduled every epoch. Else, the LR will be scheduled every batch. + store_lr_in_state (bool, optional): If `True`, store the learning rate in each state, defaults to `True`. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 1. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__( + self, + lr_scheduler, + by_epoch: bool, + store_lr_in_state: bool = True, + priority: int = 1, + ): + super().__init__(priority=priority) + self.by_epoch = by_epoch + self.lr_scheduler = lr_scheduler + self.store_lr_in_state = store_lr_in_state + + def after_hook_is_attached(self, trainer): + self._check_metric_states_initialization(trainer) + trainer.states['metrics']['train']['LR'] = LearningRateMetric(epoch_only=self.by_epoch, + initial_lr=self.lr_scheduler.get_last_lr()[0]) + + def after_train_epoch(self, trainer): + if self.by_epoch: + self.lr_scheduler.step() + trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0]) + + def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): + if not self.by_epoch: + self.lr_scheduler.step() + trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0]) diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..526d6c746ec6511c97b283ac1074340daabb2516 --- /dev/null +++ b/colossalai/trainer/hooks/_metric_hook.py @@ -0,0 +1,437 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod +from typing import Callable + +import torch +import torch.distributed as dist +from colossalai.communication import all_reduce +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.registry import HOOKS +from colossalai.utils import get_current_device, is_no_pp_or_last_stage + +from ._base_hook import BaseHook +from ._commons_ import _format_number + + +class Metric(ABC): + """A basic class of metric collectors. It collects a specific + metric during training or evaluation and would always be used with + :class:`MetricHook` to help it update its states and show the + metric. So please use corresponding hook class to make the metric + collector works. + + Args: + epoch_only (bool): Whether the metric only read for the full epoch. + """ + + def __init__(self, epoch_only: bool): + # is the metric only read for the full epoch + self._epoch_only = epoch_only + + @property + def epoch_only(self): + """Returns :attr:`epoch_only`. + """ + return self._epoch_only + + @abstractmethod + def reset(self) -> None: + """Resets the metric to it's initial state. + By default, this is called at the start of each epoch. + """ + pass + + @abstractmethod + def update(self, *args, **kwargs) -> None: + """Updates the metric's state using the passed batch output. + By default, this is called once for each batch. + """ + pass + + @abstractmethod + def get_last_step_value(self) -> float: + """Returns the metric value in the last iteration. + """ + pass + + @abstractmethod + def get_accumulated_value(self): + """Computes the metric based on it's accumulated state. + By default, this is called at the end of each epoch. + + :return: the actual quantity of interest + :rtype: Any + """ + pass + + @staticmethod + @abstractmethod + def is_better(a, b) -> bool: + """Compares a and b, and returns whether a is better than b + + :return: The result of comparison + :rtype: bool + """ + pass + + +class LossMetric(Metric): + """A metric collector for loss. + + Args: + epoch_only (bool): Whether the metric only read for the full epoch. + """ + + def __init__(self, epoch_only): + super().__init__(epoch_only=epoch_only) + self.last_step_loss = torch.zeros(1, device=get_current_device()) + self.accum_loss = torch.zeros(1, device=get_current_device()) + self.count = 0 + + def reset(self) -> None: + """Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero. + """ + self.last_step_loss.zero_() + self.accum_loss.zero_() + self.count = 0 + + def update(self, loss) -> None: + """Updates :attr:`last_step_loss` and :attr:`accum_loss` with current loss. + It expects the output has loss. + + Args: + loss (:class:`torch.tensor`): Current loss of the output. + """ + # expect output to be logits, label and loss + loss_ = loss.detach() + self.last_step_loss.copy_(loss_) + self.accum_loss.add_(loss_) + self.count += 1 + + def get_accumulated_value(self): + """Returns accumulated loss. + """ + if gpc.is_initialized(ParallelMode.DATA): + dist.all_reduce(self.accum_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA)) + self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA)) + + self.accum_loss.div_(self.count) + return self.accum_loss.item() + + def get_last_step_value(self) -> float: + """Returns :attr:`last_step_loss`. + """ + return self.last_step_loss.cpu().item() + + @staticmethod + def is_better(a, b): + return a < b + + +class LearningRateMetric(Metric): + """A metric collector for learning rate. + + Args: + epoch_only (bool): Whether the metric only read for the full epoch. + initial_lr (float, optional): Initial learning rate, defaults to 0.0. + """ + + def __init__(self, epoch_only: bool, initial_lr: float = 0.): + super().__init__(epoch_only=epoch_only) + self.lr = initial_lr + + def reset(self) -> None: + pass + + def update(self, lr) -> None: + self.lr = lr + + def get_last_step_value(self) -> float: + return self.lr + + def get_accumulated_value(self): + return self.lr + + @staticmethod + def is_better(a, b) -> bool: + pass + + +class AccuracyMetric(Metric): + """A metric collector for accuracy. It only works for classification + tasks. + + Args: + epoch_only (bool): Whether the metric only read for the full epoch. + accuracy_func (:class:`typing.Callable`): Accuracy function for the classification task. + """ + + def __init__(self, epoch_only: bool, accuracy_func: Callable): + super().__init__(epoch_only=epoch_only) + self.acc = accuracy_func + self.last_step_sum = torch.zeros(1, device=get_current_device()) + self.last_step_correct = torch.zeros(1, device=get_current_device()) + self.accumulated_sum = torch.zeros(1, device=get_current_device()) + self.accumulated_correct = torch.zeros(1, device=get_current_device()) + + def reset(self) -> None: + self.last_step_sum.zero_() + self.last_step_correct.zero_() + self.accumulated_sum.zero_() + self.accumulated_correct.zero_() + + def update(self, logits, targets, batch_size) -> None: + """Updates last step accuracy and accumulated accuracy with current logits + and labels. It expects the output has logits and labels. + + Args: + logits (:class:`torch.tensor`): The logits output of the model. + targets (:class:`torch.tensor`): Real labels of the dataset. + batch_size (int): Batch size of the task. + """ + if isinstance(logits, (list, tuple)): + logits = logits[0] + if isinstance(targets, (list, tuple)): + targets = targets[0] + # update + correct = self.acc(logits, targets) + + self.last_step_sum.fill_(batch_size) + self.last_step_correct.fill_(correct) + self.accumulated_sum += self.last_step_sum + self.accumulated_correct += self.last_step_correct + + def get_last_step_value(self) -> float: + self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA) + self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA) + return _format_number((self.last_step_correct / self.last_step_sum).cpu().item()) + + def get_accumulated_value(self): + self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA) + self.accumulated_correct = all_reduce(self.accumulated_correct, ParallelMode.DATA) + return (self.accumulated_correct / self.accumulated_sum).item() + + @staticmethod + def is_better(a, b) -> bool: + return a > b + + +class MetricHook(BaseHook): + """Specialized hook classes for :class:`Metric`. + Some help metric collectors initialize, reset and + update their states. Others are used to display and + record the metric. + + Args: + priority (int): Priority in the printing, hooks with small priority will be printed in front + defaults to 1. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__( + self, + priority: int, + ): + super().__init__(priority) + self._is_stage_to_compute = is_no_pp_or_last_stage() + + def _check_metric_states_initialization(self, trainer): + if 'metrics' not in trainer.states: + self.init_runner_states(trainer, 'metrics', dict(train={}, test={})) + + +@HOOKS.register_module +class LossHook(MetricHook): + """Specialized hook class for :class:`Loss`. + + Args: + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 0. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__(self, priority: int = 0): + super().__init__(priority) + + def after_hook_is_attached(self, trainer): + self._check_metric_states_initialization(trainer) + + if self._is_stage_to_compute: + self.train_loss = LossMetric(epoch_only=False) + self.test_loss = LossMetric(epoch_only=True) + + # register the metric calculator + trainer.states['metrics']['train']['Loss'] = self.train_loss + trainer.states['metrics']['test']['Loss'] = self.test_loss + + def before_train_epoch(self, trainer): + if self._is_stage_to_compute: + self.train_loss.reset() + + def after_train_iter(self, trainer, logits, label, loss): + if self._is_stage_to_compute: + self.train_loss.update(loss) + + def before_test_epoch(self, trainer): + if self._is_stage_to_compute: + self.test_loss.reset() + + def after_test_iter(self, trainer, logits, label, loss): + if self._is_stage_to_compute: + self.test_loss.update(loss) + + +@HOOKS.register_module +class AccuracyHook(MetricHook): + """Specialized hook class for :class:`Accuracy`. + + Args: + accuracy_func (:class:`typing.Callable`): Accuracy function for the classification task. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 0. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__(self, accuracy_func: Callable, priority: int = 0): + super().__init__(priority) + self.accuracy_func = accuracy_func + + def after_hook_is_attached(self, trainer): + self._check_metric_states_initialization(trainer) + if self._is_stage_to_compute: + self.metric = AccuracyMetric(epoch_only=True, accuracy_func=self.accuracy_func) + + # register the metric + trainer.states['metrics']['test']['Accuracy'] = self.metric + + def before_test(self, trainer): + if self._is_stage_to_compute: + self.metric.reset() + + def after_test_iter(self, trainer, logits, targets, *args): + if self._is_stage_to_compute: + batch_size = trainer.engine.schedule.batch_size + self.metric.update(logits, targets, batch_size) + + +class ThroughputMetric(Metric): + """Metric for :class:`Throughput`. + + Args: + epoch_only (bool): Whether the metric only read for the full epoch. + """ + + def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflop_per_step: int = 0, use_local: bool = False): + super().__init__(epoch_only=epoch_only) + self.ignored_steps = ignored_steps + self.cur_steps = 0 + self.accumulated_num_samples = torch.zeros(1, device=get_current_device()) + self.accumulated_used_time = torch.zeros(1, device=get_current_device()) + self.last_step_num_samples = torch.zeros(1, device=get_current_device()) + self.last_step_used_time = torch.zeros(1, device=get_current_device()) + self._tflop_per_step = tflop_per_step + self._use_local = use_local + + def reset(self) -> None: + # self.cur_steps = 0 + self.accumulated_num_samples.zero_() + self.accumulated_used_time.zero_() + self.last_step_num_samples.zero_() + self.last_step_used_time.zero_() + + def update(self, num_samples, time) -> None: + self.cur_steps += 1 + self.last_step_num_samples.fill_(num_samples) + self.last_step_used_time.fill_(time) + if self.cur_steps >= self.ignored_steps: + self.accumulated_num_samples += self.last_step_num_samples + self.accumulated_used_time += self.last_step_used_time + + def get_last_step_value(self) -> float: + if self._use_local: + self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA) + else: + self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \ + gpc.get_world_size(ParallelMode.DATA) + self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA) + + sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item()) + return sample_per_sec + + def get_last_step_info(self) -> str: + if self._use_local: + self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA) + else: + self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \ + gpc.get_world_size(ParallelMode.DATA) + self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA) + + sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item()) + if self._tflop_per_step > 0: + tflops = _format_number(self._tflop_per_step / (self.last_step_used_time.item() + 1e-12)) + return f"{sample_per_sec} sample_per_sec, {tflops} Tflops" + else: + return f"{sample_per_sec} sample_per_sec" + + def get_accumulated_value(self) -> float: + self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \ + gpc.get_world_size(ParallelMode.DATA) + self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA) + return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item() + + @staticmethod + def is_better(a, b) -> bool: + pass + + +@HOOKS.register_module +class ThroughputHook(MetricHook): + """Specialized hook class for :class:`Throughput`. Hook to measure execution throughput (samples/sec). + + Args: + ignored_steps (int, optional): the number of initial training steps to ignore. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + tflop_per_step(int, optional): tera floating point operations per step. + use_local (bool, optional): Whether to use local time for throughput calculation. + """ + + def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: int = 0, use_local=False): + super().__init__(priority) + self.ignored_steps = ignored_steps + self._tflop_per_step = tflop_per_step + self._use_local = use_local + + def after_hook_is_attached(self, trainer): + self._check_metric_states_initialization(trainer) + if self._is_stage_to_compute: + self.metric = ThroughputMetric(epoch_only=True, + ignored_steps=self.ignored_steps, + tflop_per_step=self._tflop_per_step, + use_local=self._use_local) + + # register the metric + trainer.states['metrics']['train']['Throughput'] = self.metric + trainer.states['metrics']['test']['Throughput'] = self.metric + + def before_train_epoch(self, trainer): + if self._is_stage_to_compute: + self.metric.reset() + + def after_train_iter(self, trainer, *args): + if self._is_stage_to_compute: + self.metric.update(trainer.engine.schedule.batch_size, + trainer._timer.get_timer('Train-step').get_elapsed_time()) + + def before_test(self, trainer): + if self._is_stage_to_compute: + self.metric.reset() + + def after_test_iter(self, trainer, *args): + if self._is_stage_to_compute: + self.metric.update(trainer.engine.schedule.batch_size, + trainer._timer.get_timer('Test-step').get_elapsed_time()) diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..875b5a93ba4f3a244b4a2a8332a43b84d316707d --- /dev/null +++ b/colossalai/utils/__init__.py @@ -0,0 +1,53 @@ +from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize +from .activation_checkpoint import checkpoint +from .checkpointing import load_checkpoint, save_checkpoint +from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32, + ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, + is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier, + param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank, + sync_model_param, disposable) +from .data_sampler import DataParallelSampler, get_dataloader +from .memory import (report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction, + colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity) +from .timer import MultiTimer, Timer +from .tensor_detector import TensorDetector + +__all__ = [ + 'checkpoint', + 'free_port', + 'print_rank_0', + 'sync_model_param', + 'is_dp_rank_0', + 'is_tp_rank_0', + 'is_no_pp_or_last_stage', + 'is_using_ddp', + 'is_using_pp', + 'is_using_sequence', + 'conditional_context', + 'is_model_parallel_parameter', + 'clip_grad_norm_fp32', + 'count_zeros_fp32', + 'copy_tensor_parallel_attributes', + 'param_is_not_tensor_parallel_duplicate', + 'get_current_device', + 'synchronize', + 'empty_cache', + 'set_to_cuda', + 'report_memory_usage', + 'colo_device_memory_capacity', + 'colo_device_memory_used', + 'colo_set_process_memory_fraction', + 'Timer', + 'MultiTimer', + 'multi_tensor_applier', + 'DataParallelSampler', + 'get_dataloader', + 'switch_virtual_pipeline_parallel_rank', + 'TensorDetector', + 'load_checkpoint', + 'save_checkpoint', + 'ensure_path_exists', + 'disposable', + 'colo_set_cpu_memory_capacity', + 'colo_get_cpu_memory_capacity', +] diff --git a/colossalai/utils/activation_checkpoint.py b/colossalai/utils/activation_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..fa9ed827a8a7fa649d5689d15a16050c9bc877b6 --- /dev/null +++ b/colossalai/utils/activation_checkpoint.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +from torch.utils.checkpoint import check_backward_validity, detach_variable + +from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states +from .cuda import get_current_device + +import weakref + + +def copy_to_device(obj, device): + if torch.is_tensor(obj): + # Notice: + # When in no_grad context, requires_gard is False after movement + ret = obj.to(device).detach() + ret.requires_grad = obj.requires_grad + return ret + elif isinstance(obj, list): + return [copy_to_device(i, device) for i in obj] + elif isinstance(obj, tuple): + return tuple([copy_to_device(v, device) for v in obj]) + elif isinstance(obj, dict): + return {k: copy_to_device(v, device) for k, v in obj.items()} + else: + return obj + + +class CheckpointFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, run_function, activation_offload=False, *args): + check_backward_validity(args) + ctx.run_function = run_function + ctx.activation_offload = activation_offload + ctx.device = get_current_device() + + # preserve rng states + ctx.fwd_cpu_rng_state = torch.get_rng_state() + sync_states() + ctx.fwd_seed_states = get_states(copy=True) + ctx.fwd_current_mode = get_current_mode() + + if hasattr(torch, 'is_autocast_enabled'): + ctx.had_autocast_in_fwd = torch.is_autocast_enabled() + else: + ctx.had_autocast_in_fwd = False + + if activation_offload: + inputs_cuda = copy_to_device(args, ctx.device) + else: + inputs_cuda = args + + with torch.no_grad(): + outputs = run_function(*inputs_cuda) + # Save non-tensor inputs in ctx, keep a placeholder None for tensors + # to be filled out during the backward. + ctx.inputs = [] + ctx.tensor_indices = [] + tensor_inputs = [] + for i, arg in enumerate(args): + if torch.is_tensor(arg): + if activation_offload: + tensor_inputs.append(copy_to_device(arg, 'cpu')) + else: + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + if activation_offload: + ctx.tensor_inputs = tensor_inputs + else: + ctx.save_for_backward(*tensor_inputs) + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter is " + "passed to .backward(). Please use .backward() and do not pass its `inputs` argument.") + # Copy the list to avoid modifying original list. + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + + if ctx.activation_offload: + tensors = ctx.tensor_inputs + else: + tensors = ctx.saved_tensors + + # store the current states + bwd_cpu_rng_state = torch.get_rng_state() + sync_states() + bwd_seed_states = get_states(copy=True) + bwd_current_mode = get_current_mode() + + # set the states to what it used to be + torch.set_rng_state(ctx.fwd_cpu_rng_state) + for parallel_mode, state in ctx.fwd_seed_states.items(): + set_seed_states(parallel_mode, state) + set_mode(ctx.fwd_current_mode) + if ctx.activation_offload: + tensors = copy_to_device(tensors, ctx.device) + + # Fill in inputs with appropriate saved tensors. + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + detached_inputs = detach_variable(tuple(inputs)) + if ctx.had_autocast_in_fwd: + with torch.enable_grad(), torch.cuda.amp.autocast(): + outputs = ctx.run_function(*detached_inputs) + else: + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + # recover the rng states + torch.set_rng_state(bwd_cpu_rng_state) + for parallel_mode, state in bwd_seed_states.items(): + set_seed_states(parallel_mode, state) + set_mode(bwd_current_mode) + + # run backward() with only tensor that requires grad + outputs_with_grad = [] + args_with_grad = [] + for i in range(len(outputs)): + if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: + outputs_with_grad.append(outputs[i]) + args_with_grad.append(args[i]) + if len(outputs_with_grad) == 0: + raise RuntimeError("none of output has requires_grad=True," + " this checkpoint() is not necessary") + torch.autograd.backward(outputs_with_grad, args_with_grad) + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs) + return (None, None) + grads + + +def checkpoint(function, activation_offload, *args, use_reentrant: bool = True): + """Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint. + + Args: + function: Describe the forward pass function. It should know how to handle the input tuples. + activation_offload: The variable to check whether we should offload activation to cpu + args (list): Tuple containing the parameters of the function + use_reentrant: Bool type to check if we need to use_reentrant, if use_reentrant=False, there + might be more flexibility for user to define there checkpoint function + + Returns: + Output of running function with provided args. + """ + if use_reentrant: + return CheckpointFunction.apply(function, activation_offload, *args) + else: + return _checkpoint_without_reentrant( + function, + activation_offload, + *args, + ) + + +def _checkpoint_without_reentrant(function, activation_offload=False, *args): + # store rng_state + fwd_cpu_state = torch.get_rng_state() + sync_states() + fwd_seed_states = get_states(copy=True) + fwd_current_mode = get_current_mode() + + # check if use autocast + if hasattr(torch, 'is_autocast_enabled'): + has_autocast_in_fwd = torch.is_autocast_enabled() + else: + has_autocast_in_fwd = False + + # using WeakKeyDictionary to store all the activation the first time we call unpack + storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + weak_holder_list = [] + + # class for weakref.ref + class Holder(): + pass + + # return a Holder object for later unpack process + def pack(x): + res = Holder() + weak_holder_list.append(weakref.ref(res)) + return res + + # unpack hook + def unpack(x): + unpack_counter = 0 + + # re-compute all the activation inside the function when we first call unpack + if len(storage) == 0: + + def inner_pack(inner): + nonlocal unpack_counter + unpack_counter += 1 + + # If the holder went out of scope, the SavedVariable is dead and so + # the value will never be read from the storage. Skip filling it. + if weak_holder_list[unpack_counter - 1]() is None: + return + + # Use detach here to ensure we don't keep the temporary autograd + # graph created during the second forward + storage[weak_holder_list[unpack_counter - 1]()] = inner.detach() + return + + def inner_unpack(packed): + raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.") + + # restore rng state + torch.set_rng_state(fwd_cpu_state) + for parallel_mode, state in fwd_seed_states.items(): + set_seed_states(parallel_mode, state) + set_mode(fwd_current_mode) + + # reload arg into device if needed + if activation_offload: + for arg in args: + if torch.is_tensor(arg): + arg = arg.to(device=device) + + # rerun forward, the inner_pack will store all the activations in storage + if has_autocast_in_fwd: + with torch.enable_grad(), \ + torch.cuda.amp.autocast(), \ + torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + _unused = function(*args) + else: + with torch.enable_grad(), \ + torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + _unused = function(*args) + + if x not in storage: + raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint" + " recomputation being triggered in between, this is not currently supported. Please" + " open an issue with details on your use case so that we can prioritize adding this.") + + return storage[x] + + # get device if we need to offload the activation + if activation_offload: + device = get_current_device() + + # run function with pack and unpack as saved_tensors_hooks + with torch.autograd.graph.saved_tensors_hooks(pack, unpack): + output = function(*args) + + # offload activation if needed + if activation_offload: + for arg in args: + if torch.is_tensor(arg): + arg = arg.to(device="cpu") + + return output diff --git a/colossalai/utils/checkpoint/__init__.py b/colossalai/utils/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1795b4ce36f41d2a09da0c324db4cb1ef21c5e2c --- /dev/null +++ b/colossalai/utils/checkpoint/__init__.py @@ -0,0 +1,3 @@ +from .module_checkpoint import save_checkpoint, load_checkpoint + +__all__ = ['save_checkpoint', 'load_checkpoint'] diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..a109b3702577be896a032b7caa277a423c2aeb66 --- /dev/null +++ b/colossalai/utils/checkpoint/module_checkpoint.py @@ -0,0 +1,137 @@ +import torch +import torch.distributed as dist +from colossalai.tensor import ColoTensor +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor +from typing import Optional, Dict + + +def save_checkpoint(path: str, + epoch: int, + model: torch.nn.Module, + optimizer: Optional[ColossalaiOptimizer] = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + *args, + **kwargs): + """save_checkpoint + save a model, whose parameters are `ColoTensor`s. + Args: + path (str): directory to save the checkpoint files. + epoch (int): the number of epoch + model (torch.nn.Module): a torch module initialized by ColoInitContext + optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None. + lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None. + """ + rank = dist.get_rank() + model_state = model.state_dict() + # save the dist context about the tensors in a new dict, while still maintain the original dict. + for k, v in model_state.items(): + if isinstance(v, ColoTensor): + gather_tensor(v) # gather shared tensors to rank0 + # don't recover tensors in rank0, since the dict is only a copy of model + + if rank == 0: + # sanity check + for k, v in model_state.items(): + if isinstance(v, ColoTensor): + assert v.save_ready + assert v.is_replicate() + delattr(v, 'save_ready') + # model saving + save_state = {'epoch': epoch, 'model': model_state} + torch.save(save_state, path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs) + + # delete old dicts + del model_state + # synchronize all the processes + dist.barrier() + + if optimizer is not None: + mapping = dict() + optim_state = optimizer.state_dict() + for k, v in optim_state['state'].items(): + for n, t in v.items(): + if isinstance(t, ColoTensor): + mapping[(k, n)] = t.dist_spec + gather_tensor(t) + + if rank == 0: + save_state = {'epoch': epoch, 'optim': optim_state} + torch.save(save_state, path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs) + # recover colo tensors in rank0 + for k, v in optimizer.state_dict()['state'].items(): + for n, t in v.items(): + if isinstance(t, ColoTensor): + assert hasattr(t, 'save_ready') + t.set_dist_spec(mapping[(k, n)]) + delattr(t, 'save_ready') + + del optim_state + del mapping + dist.barrier() + + +def load_checkpoint(path: str, + epoch: int, + model: torch.nn.Module, + optimizer: Optional[ColossalaiOptimizer] = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + torch_load_kwargs: Optional[Dict] = None, + load_state_dict_kwargs: Optional[Dict] = None): + """load_checkpoint + load a model, whose parameters are `ColoTensor`s. + Args: + path (str): directory to save the checkpoint files. + epoch (int): the number of epoch + model (torch.nn.Module): a torch module initialized by ColoInitContext + optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None. + lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None. + torch_load_kwargs: (dict, optional): The kwargs of torch.load inside the function + load_state_dict_kwargs (dict, optional): The kwargs of load_state_dict inside the function + """ + # initialize the default paramters + if not torch_load_kwargs: + torch_load_kwargs = dict() + if not load_state_dict_kwargs: + load_state_dict_kwargs = dict() + + rank = dist.get_rank() + mapping = dict() + for n, p in model.named_parameters(): + if isinstance(p, ColoTensor): + mapping[n] = p.dist_spec + gather_tensor(p) + + if rank == 0: + load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), **torch_load_kwargs) + model.load_state_dict(load_state['model'], **load_state_dict_kwargs) + dist.barrier() + + # scatter loaded parameters + for n, p in model.named_parameters(): + if isinstance(p, ColoTensor): + scatter_tensor(p, mapping[n]) + if rank == 0: + assert hasattr(p, 'save_ready') + delattr(p, 'save_ready') + del mapping + + if optimizer is not None: + mapping = dict() + for k, v in optimizer.state_dict()['state'].items(): + for n, t in v.items(): + if isinstance(t, ColoTensor): + mapping[(k, n)] = t.dist_spec + gather_tensor(t) + + if rank == 0: + colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), **torch_load_kwargs) + optimizer.load_state_dict(colo_checkpoint['optim'], **load_state_dict_kwargs) + dist.barrier() + + for k, v in optimizer.state_dict()['state'].items(): + for n, t in v.items(): + if isinstance(t, ColoTensor): + scatter_tensor(t, mapping[(k, n)]) + + del mapping diff --git a/colossalai/utils/checkpoint/utils.py b/colossalai/utils/checkpoint/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a643a6e10dff9e065975f2625f246c80ebb03417 --- /dev/null +++ b/colossalai/utils/checkpoint/utils.py @@ -0,0 +1,63 @@ +import torch +import torch.distributed as dist +from colossalai.tensor import ColoTensor, ColoTensorSpec +from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern + + +def robust_broadcast(tensor): + with torch.no_grad(): + is_cpu_ten = tensor.device.type == 'cpu' + if is_cpu_ten: + b_data = tensor.cuda() + else: + b_data = tensor + + dist.broadcast(b_data, 0) + + if is_cpu_ten: + tensor.copy_(b_data) + + +def gather_tensor(colo_tensor: ColoTensor) -> None: + """Make colo_tensor replicated when the rank is 0 + """ + if not colo_tensor.is_replicate(): + pg = colo_tensor.get_process_group() + # for the group which contains rank 0 + if pg.dp_local_rank() == 0: + old_dist_spec = colo_tensor.dist_spec + colo_tensor.to_replicate_() + if dist.get_rank() != 0: + colo_tensor.set_dist_spec(old_dist_spec) + + # synchronize all processes for unexpected problems + dist.barrier() + + if dist.get_rank() == 0: + setattr(colo_tensor, 'save_ready', True) # set saving signitrue + + +def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None: + """Reversal operation of `gather_tensor`. + """ + if dist_spec.placement == DistPlacementPattern.REPLICATE: + robust_broadcast(colo_tensor.data) + else: + global_size = colo_tensor.size_global() + + if dist.get_rank() == 0: + entire_data = colo_tensor.data + else: + entire_data = torch.empty(global_size, device=colo_tensor.device) + robust_broadcast(entire_data) + + if dist.get_rank() == 0: + colo_tensor.set_dist_spec(dist_spec) + else: + rep_tensor = ColoTensor( + entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec)) + rep_tensor.set_dist_spec(dist_spec) + with torch.no_grad(): + colo_tensor.data.copy_(rep_tensor.data) + # synchronize all processes for unexpected problems + dist.barrier() diff --git a/colossalai/utils/checkpoint_io/__init__.py b/colossalai/utils/checkpoint_io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe030866894f7666195f1563f6699ae88aa78a61 --- /dev/null +++ b/colossalai/utils/checkpoint_io/__init__.py @@ -0,0 +1,2 @@ +from .io import load, merge, redist, save +from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta) diff --git a/colossalai/utils/checkpoint_io/backend.py b/colossalai/utils/checkpoint_io/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..140192c05f12cf4843df36d43c66723469ef6cad --- /dev/null +++ b/colossalai/utils/checkpoint_io/backend.py @@ -0,0 +1,74 @@ +import shutil +import tempfile +from abc import ABC, abstractmethod +from typing import Dict, List, Type + +from .reader import CheckpointReader, DiskCheckpointReader +from .writer import CheckpointWriter, DiskCheckpointWriter + +_backends: Dict[str, Type['CheckpointIOBackend']] = {} + + +def register(name: str): + assert name not in _backends, f'"{name}" is registered' + + def wrapper(cls): + _backends[name] = cls + return cls + + return wrapper + + +def get_backend(name: str) -> 'CheckpointIOBackend': + assert name in _backends, f'Unsupported backend "{name}"' + return _backends[name]() + + +class CheckpointIOBackend(ABC): + + def __init__(self) -> None: + super().__init__() + self.temps: List[str] = [] + + @abstractmethod + def get_writer(self, + base_name: str, + overwrite: bool = False, + rank: int = 0, + world_size: int = 1) -> CheckpointWriter: + pass + + @abstractmethod + def get_reader(self, base_name: str) -> CheckpointReader: + pass + + @abstractmethod + def get_temp(self, base_name: str) -> str: + pass + + @abstractmethod + def clean_temp(self) -> None: + pass + + +@register('disk') +class CheckpointDiskIO(CheckpointIOBackend): + + def get_writer(self, + base_name: str, + overwrite: bool = False, + rank: int = 0, + world_size: int = 1) -> CheckpointWriter: + return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size) + + def get_reader(self, base_name: str) -> CheckpointReader: + return DiskCheckpointReader(base_name) + + def get_temp(self, base_name: str) -> str: + temp_dir_name = tempfile.mkdtemp(dir=base_name) + self.temps.append(temp_dir_name) + return temp_dir_name + + def clean_temp(self) -> None: + for temp_dir_name in self.temps: + shutil.rmtree(temp_dir_name) diff --git a/colossalai/utils/checkpoint_io/constant.py b/colossalai/utils/checkpoint_io/constant.py new file mode 100644 index 0000000000000000000000000000000000000000..2199484741bf5bb934d5c0583dd55f51f0cdbffe --- /dev/null +++ b/colossalai/utils/checkpoint_io/constant.py @@ -0,0 +1,9 @@ +import re + +GLOBAL_META_FILE_NAME = 'global_meta.bin' +MODEL_CKPT_FILE_NAME = 'model.bin' +OPTIM_CKPT_FILE_NAME = 'optim.bin' +META_CKPT_FILE_NAME = 'meta.bin' +OTHER_CKPT_FILE_NAME = 'other.bin' + +CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other') diff --git a/colossalai/utils/checkpoint_io/convertor.py b/colossalai/utils/checkpoint_io/convertor.py new file mode 100644 index 0000000000000000000000000000000000000000..529ceb86829b511d05911fe6103d5ef53ecf325d --- /dev/null +++ b/colossalai/utils/checkpoint_io/convertor.py @@ -0,0 +1,227 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional + +from torch import Tensor + +from .distributed import merge_param, unmerge_param +from .meta import ParamDistMeta, RedistMeta +from .utils import (ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none) + + +class CheckpointConvertor(ABC): + + @abstractmethod + def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + pass + + @abstractmethod + def complete(self) -> None: + pass + + +class ModelCheckpointConvertor(CheckpointConvertor): + + def __init__(self, param_count: Dict[str, int]) -> None: + super().__init__() + self.param_count = param_count + self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict) + + @abstractmethod + def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: + pass + + def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + for rank, state_dict in shard_dict.items(): + for k, tensor in state_dict.items(): + self.buffer[k][rank] = tensor + converted_keys = set() + for k, rank_dict in self.buffer.items(): + if len(rank_dict) == self.param_count[k]: + tensors = [] + dist_metas = [] + for rank, tensor in rank_dict.items(): + tensors.append(tensor) + if dist_meta_list[rank] is not None: + dist_metas.append(dist_meta_list[rank][k]) + self.convert_tensors(k, tensors, dist_metas) + converted_keys.add(k) + for k in converted_keys: + del self.buffer[k] + + def complete(self) -> None: + assert len(self.buffer) == 0 + + +class ModelCheckpointMerger(ModelCheckpointConvertor): + + def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None: + super().__init__(param_count) + self.sharder = ModelCheckpointSharder(max_shard_size) + self.save_fn = save_fn + + def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: + assert len(dist_metas) == len(tensors) + tensor = merge_param(tensors, dist_metas) + shard = self.sharder.append(key, tensor) + run_if_not_none(self.save_fn, shard) + + def complete(self) -> None: + super().complete() + run_if_not_none(self.save_fn, self.sharder.complete()) + + +class ModelCheckpointRedistor(ModelCheckpointConvertor): + + def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int], + redist_meta: RedistMeta) -> None: + super().__init__(param_count) + self.save_fns = save_fns + self.redist_meta = redist_meta + nprocs = len(save_fns) + self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in range(nprocs)] + self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for k, rank_meta in redist_meta.rank_meta.items(): + for rank, rank_info in rank_meta.items(): + self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank) + + def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: + if len(dist_metas) == 0: + # already global + tensor = tensors[0] + else: + assert len(dist_metas) == len(tensors) + tensor = merge_param(tensors, dist_metas) + for tp_rank, tensor_list in enumerate(unmerge_param(tensor, self.redist_meta.param_meta[key])): + for dp_rank, t in enumerate(tensor_list): + for rank in self.rank_map[key][tp_rank][dp_rank]: + shard = self.sharders[rank].append(key, t) + run_if_not_none(self.save_fns[rank], shard) + + def complete(self) -> None: + super().complete() + for rank, save_fn in enumerate(self.save_fns): + run_if_not_none(save_fn, self.sharders[rank].complete()) + + +class OptimizerCheckpointConvertor(CheckpointConvertor): + + def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]], + paired_os: Optional[Dict[int, dict]]) -> None: + super().__init__() + self.param_count = param_count + self.param_to_os = param_to_os + self.paired_os = paired_os + self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict) + self.os_to_param = {v: k for k, v in param_to_os.items()} + + @abstractmethod + def setup(self, param_groups: dict) -> None: + pass + + @abstractmethod + def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: + pass + + def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + for rank, state_dict in shard_dict.items(): + self.setup(state_dict['param_groups']) + for idx, state in state_dict['state'].items(): + self.buffer[idx][rank] = state + converted_indices = set() + for idx, rank_dict in self.buffer.items(): + if len(rank_dict) == self.param_count[self.os_to_param[idx]]: + states = [] + dist_metas = [] + for rank, state in rank_dict.items(): + states.append(state) + if dist_meta_list[rank] is not None: + dist_metas.append(dist_meta_list[rank][self.os_to_param[idx]]) + self.convert_states(idx, states, dist_metas) + converted_indices.add(idx) + for idx in converted_indices: + del self.buffer[idx] + + def complete(self) -> None: + assert len(self.buffer) == 0 + + +class OptimizerCheckpointMerger(OptimizerCheckpointConvertor): + + def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int], + param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None: + super().__init__(param_count, param_to_os, paired_os) + self.max_shard_size = max_shard_size + self.save_fn = save_fn + self.sharder = None + + def setup(self, param_groups: dict) -> None: + if self.sharder is None: + self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups) + + def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: + assert len(dist_metas) == len(states) + new_state = {} + for state_key, state_tensor in states[0].items(): + if self.paired_os[idx][state_key]: + new_state[state_key] = merge_param([state[state_key] for state in states], dist_metas) + else: + new_state[state_key] = state_tensor + shard = self.sharder.append(idx, new_state) + run_if_not_none(self.save_fn, shard) + + def complete(self) -> None: + super().complete() + run_if_not_none(self.save_fn, self.sharder.complete()) + + +class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor): + + def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int], + param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]], + redist_meta: RedistMeta) -> None: + super().__init__(param_count, param_to_os, paired_os) + self.max_shard_size = max_shard_size + self.save_fns = save_fns + self.redist_meta = redist_meta + self.sharders: List[OptimizerCheckpointSharder] = [] + self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for k, rank_meta in redist_meta.rank_meta.items(): + for rank, rank_info in rank_meta.items(): + self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank) + + def setup(self, param_groups: dict) -> None: + if len(self.sharders) == 0: + nprocs = len(self.save_fns) + for _ in range(nprocs): + self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups)) + + def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: + need_merge: bool = True + if len(dist_metas) == 0: + need_merge = False + else: + assert len(dist_metas) == len(states) + new_states = [{} for _ in range(len(self.save_fns))] + for state_key, state_tensor in states[0].items(): + if self.paired_os[idx][state_key]: + if need_merge: + tensor = merge_param([state[state_key] for state in states], dist_metas) + else: + tensor = state_tensor + for tp_rank, tensor_list in enumerate( + unmerge_param(tensor, self.redist_meta.param_meta[self.os_to_param[idx]])): + for dp_rank, t in enumerate(tensor_list): + for rank in self.rank_map[self.os_to_param[idx]][tp_rank][dp_rank]: + new_states[rank][state_key] = t + else: + for new_state in new_states: + new_state[state_key] = state_tensor + for rank, new_state in enumerate(new_states): + shard = self.sharders[rank].append(idx, new_state) + run_if_not_none(self.save_fns[rank], shard) + + def complete(self) -> None: + super().complete() + for rank, save_fn in enumerate(self.save_fns): + run_if_not_none(save_fn, self.sharders[rank].complete()) diff --git a/colossalai/utils/checkpoint_io/distributed.py b/colossalai/utils/checkpoint_io/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..bf720437c41a9114dc805647e415179a9b1243a9 --- /dev/null +++ b/colossalai/utils/checkpoint_io/distributed.py @@ -0,0 +1,127 @@ +import torch +from numpy import prod +from torch import Tensor +from typing import List, Optional, Tuple +from collections import defaultdict +from .meta import ParamDistMeta, ParamRedistMeta + + +def unflatten_zero_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: + assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) + for dist_meta in dist_metas[1:]: + assert dist_meta.zero_meta == dist_metas[0].zero_meta, 'Expect all params have the same zero meta.' + if not dist_metas[0].used_zero: + # tensors are replicate + return tensors[0] + numel = dist_metas[0].zero_numel + orig_shape = dist_metas[0].zero_orig_shape + tensors = [t[1] for t in sorted(zip(dist_metas, tensors), key=lambda tp: tp[0].dp_rank)] + assert numel == sum(t.numel() for t in tensors), 'Expect numel of all params is equal to zero_numel.' + return torch.cat(tensors).reshape(orig_shape) + + +def gather_tp_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: + assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) + for dist_meta in dist_metas[1:]: + assert dist_meta.tp_meta == dist_metas[0].tp_meta, 'Expect all params have the same tp meta.' + for t in tensors[1:]: + assert t.shape == tensors[0].shape, 'Expect all params have the same shape.' + if not dist_metas[0].used_tp: + # tensors are replicate + return tensors[0] + total_parts = prod(dist_meta.tp_num_parts) + assert dist_meta.tp_world_size == total_parts, \ + f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {dist_meta.tp_world_size}.' + shard_info = sorted(zip(dist_meta.tp_shard_dims, dist_meta.tp_num_parts), key=lambda t: t[0], reverse=True) + for dim, num_parts in shard_info: + buffer = [] + for start in range(0, len(tensors), num_parts): + buffer.append(torch.cat(tensors[start:start + num_parts], dim)) + tensors = buffer + assert len(tensors) == 1 + return tensors[0] + + +def validate_parallel_info(dist_metas: List[ParamDistMeta]) -> None: + assert len(dist_metas) > 0 + # check world size + for dist_meta in dist_metas[1:]: + assert dist_meta.dp_world_size == dist_metas[ + 0].dp_world_size, 'Expect all dist meta have the same dp_world_size' + assert dist_meta.tp_world_size == dist_metas[ + 0].tp_world_size, 'Expect all dist meta have the same tp_world_size' + + +def deduplicate_params(tensors: List[Tensor], + dist_metas: List[ParamDistMeta]) -> Tuple[List[Tensor], List[ParamDistMeta]]: + unique_dist_meta = [] + unique_idx = [] + for i, dist_meta in enumerate(dist_metas): + if dist_meta not in unique_dist_meta: + unique_dist_meta.append(dist_meta) + unique_idx.append(i) + return [tensors[i] for i in unique_idx], [dist_metas[i] for i in unique_idx] + + +def merge_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: + assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) + # validate parallel info + validate_parallel_info(dist_metas) + tensors, dist_metas = deduplicate_params(tensors, dist_metas) + unflattened_tensors = [] + # group zero params by tp rank + tensor_dict = defaultdict(list) + dist_meta_dict = defaultdict(list) + for t, dist_meta in zip(tensors, dist_metas): + tensor_dict[dist_meta.tp_rank].append(t) + dist_meta_dict[dist_meta.tp_rank].append(dist_meta) + assert len(tensor_dict + ) == dist_metas[0].tp_world_size, f'Expect {dist_metas[0].tp_world_size} ranks, got {len(tensor_dict)}' + for tp_rank in tensor_dict.keys(): + unflattened_tensors.append(unflatten_zero_param(tensor_dict[tp_rank], dist_meta_dict[tp_rank])) + return gather_tp_param(unflattened_tensors, [dist_meta_list[0] for dist_meta_list in dist_meta_dict.values()]) + + +def split_tp_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]: + if not redist_meta.used_tp: + assert redist_meta.tp_world_size == 1, 'Expect tp_world_size == 1, when no tp meta provided.' + return [tensor] + total_parts = prod(redist_meta.tp_num_parts) + assert redist_meta.tp_world_size == total_parts, f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {redist_meta.tp_world_size}.' + shard_info = sorted(zip(redist_meta.tp_shard_dims, redist_meta.tp_num_parts), key=lambda t: t[0]) + tensors = [tensor] + for dim, num_parts in shard_info: + buffer = [] + for t in tensors: + assert t.size(dim) % num_parts == 0, \ + f'Expect dim{dim} of tensor({tensor.shape}) is divisible by {num_parts}.' + chunks = [chunk.contiguous() for chunk in t.chunk(num_parts, dim)] + buffer.extend(chunks) + tensors = buffer + assert len(tensors) == redist_meta.tp_world_size + return tensors + + +def flatten_zero_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]: + if not redist_meta.used_zero: + return [tensor] * redist_meta.dp_world_size + tensors: List[Optional[Tensor]] = [ + torch.empty(0, dtype=tensor.dtype, device=tensor.device) for _ in range(redist_meta.zero_start_dp_rank) + ] + offsets = redist_meta.zero_offsets + [tensor.numel()] + for i, offset in enumerate(offsets[:-1]): + end = offsets[i + 1] + tensors.append(tensor.view(-1)[offset:end]) + if len(tensors) < redist_meta.dp_world_size: + tensors.extend([ + torch.empty(0, dtype=tensor.dtype, device=tensor.device) + for _ in range(redist_meta.dp_world_size - len(tensors)) + ]) + assert len(tensors) == redist_meta.dp_world_size + return tensors + + +def unmerge_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[List[Tensor]]: + tensors = split_tp_param(tensor, redist_meta) + tensors = [flatten_zero_param(t, redist_meta) for t in tensors] + return tensors diff --git a/colossalai/utils/checkpoint_io/io.py b/colossalai/utils/checkpoint_io/io.py new file mode 100644 index 0000000000000000000000000000000000000000..f00212cdf85986ff7b3526a2fd4d56210459dda2 --- /dev/null +++ b/colossalai/utils/checkpoint_io/io.py @@ -0,0 +1,170 @@ +import warnings +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple + +import torch.distributed as dist +from torch.nn import Module +from torch.optim import Optimizer + +from .backend import get_backend +from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger, + OptimizerCheckpointRedistor) +from .meta import ParamDistMeta, RedistMeta +from .utils import build_checkpoints, optimizer_load_state_dict + + +def save(path: str, + model: Module, + optimizer: Optional[Optimizer] = None, + param_to_os: Optional[Dict[str, int]] = None, + dist_meta: Optional[Dict[str, ParamDistMeta]] = None, + max_shard_size_gb: float = 0.0, + overwrite: bool = False, + backend: str = 'disk', + **kwargs: Any) -> None: + io_backend = get_backend(backend) + if dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + if world_size == 1: + # global doesn't need dist_meta + dist_meta = None + else: + assert dist_meta is not None + max_shard_size = int(max_shard_size_gb * 1024**3) + model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer, + param_to_os, dist_meta) + writer = io_backend.get_writer(path, overwrite, rank, world_size) + writer.save_others(kwargs) + for model_checkpoint in model_checkpoints: + writer.save_model(model_checkpoint) + for optimizer_checkpoint in optimizer_checkpoints: + writer.save_optimizer(optimizer_checkpoint) + writer.save_meta(meta_checkpoint) + + +def merge(path: str, + output_path: str, + max_shard_size_gb: float = 0.0, + overwrite: bool = False, + backend: str = 'disk') -> bool: + io_backend = get_backend(backend) + if dist.is_initialized() and dist.get_rank() != 0: + return False + reader = io_backend.get_reader(path) + if len(reader.meta_list) == 1: + # already global + warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.') + return False + dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta() + writer = io_backend.get_writer(output_path, overwrite=overwrite) + writer.save_others(reader.load_others()) + max_shard_size = int(max_shard_size_gb * 1024**3) + _convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(), + dist_meta_list) + _convert_shards( + OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os), + reader.load_optimizers(), dist_meta_list) + meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())} + if param_to_os is not None: + meta_checkpoint['param_to_os'] = param_to_os + meta_checkpoint['paired_os'] = paired_os + writer.save_meta(meta_checkpoint) + return True + + +def redist(path: str, + output_path: str, + redist_meta: RedistMeta, + dist_metas: List[Dict[str, ParamDistMeta]], + max_shard_size_gb: float = 0.0, + overwrite: bool = False, + backend: str = 'disk') -> bool: + io_backend = get_backend(backend) + if dist.is_initialized() and dist.get_rank() != 0: + return False + nprocs = len(dist_metas) + reader = io_backend.get_reader(path) + dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta() + do_redist: bool = False + if len(dist_meta_list) == nprocs: + for a, b in zip(dist_metas, dist_meta_list): + if a != b: + do_redist = True + break + else: + do_redist = True + if not do_redist: + warnings.warn(f'Checkpoint at "{path}" is not required to redist, nothing to do.') + return False + + writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)] + writers[0].save_others(reader.load_others()) + max_shard_size = int(max_shard_size_gb * 1024**3) + _convert_shards( + ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta), + reader.load_models(), dist_meta_list) + _convert_shards( + OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers], param_count, + param_to_os, paired_os, redist_meta), reader.load_optimizers(), dist_meta_list) + for writer, dist_meta in zip(writers, dist_metas): + meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())} + if param_to_os is not None: + meta_checkpoint['param_to_os'] = param_to_os + meta_checkpoint['paired_os'] = paired_os + writer.save_meta(meta_checkpoint) + return True + + +def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None], + dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + for shard_dict in shard_generator: + convertor.append(shard_dict, dist_meta_list) + convertor.complete() + + +def load(path: str, + model: Module, + optimizer: Optional[Optimizer] = None, + redist_meta: Optional[RedistMeta] = None, + dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None, + max_shard_size_gb: float = 0.0, + backend: str = 'disk') -> dict: + is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1 + rank: int = dist.get_rank() if dist.is_initialized() else 0 + is_main_process: bool = rank == 0 + # validate args + if redist_meta is None or dist_metas is None: + assert is_global + io_backend = get_backend(backend) + read_path: str = path + if is_main_process: + # pre-process checkpoints + temp_path = io_backend.get_temp(path) + if is_global: + wrote = merge(path, temp_path, max_shard_size_gb, backend=backend) + else: + wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend) + if wrote: + read_path = temp_path + if not is_global: + bcast_list = [read_path] if is_main_process else [None] + dist.broadcast_object_list(bcast_list) + read_path = bcast_list[0] + reader = io_backend.get_reader(read_path) + # load model + for shard in reader.load_model(rank): + model.load_state_dict(shard, strict=False) + if optimizer is not None: + for shard in reader.load_optimizer(rank): + # optimizer.load_state_dict(shard) + optimizer_load_state_dict(optimizer, shard) + others_dict = reader.load_others() + if not is_global: + dist.barrier() + # clean up temp + if is_main_process: + io_backend.clean_temp() + return others_dict diff --git a/colossalai/utils/checkpoint_io/meta.py b/colossalai/utils/checkpoint_io/meta.py new file mode 100644 index 0000000000000000000000000000000000000000..994f08b4b5e44be753aa838a0c03ea693a78cc27 --- /dev/null +++ b/colossalai/utils/checkpoint_io/meta.py @@ -0,0 +1,81 @@ +from dataclasses import dataclass +from typing import List, Optional, Set, Dict + + +@dataclass +class ParamDistMeta: + # parallel info + dp_rank: int + dp_world_size: int + tp_rank: int + tp_world_size: int + # tp info + tp_shard_dims: Optional[List[int]] = None + tp_num_parts: Optional[List[int]] = None + # zero info + zero_numel: Optional[int] = None + zero_orig_shape: Optional[List[int]] = None + + @property + def used_tp(self) -> bool: + return self.tp_shard_dims is not None and self.tp_num_parts is not None + + @property + def used_zero(self) -> bool: + return self.zero_numel is not None and self.zero_orig_shape is not None + + @property + def parallel_meta(self) -> tuple: + return self.dp_rank, self.dp_world_size, self.tp_rank, self.tp_world_size + + @property + def tp_meta(self) -> tuple: + return self.tp_shard_dims, self.tp_num_parts + + @property + def zero_meta(self) -> tuple: + return self.zero_numel, self.zero_orig_shape + + @staticmethod + def from_dict(d: dict) -> 'ParamDistMeta': + return ParamDistMeta(**d) + + +@dataclass +class ParamRedistMeta: + # parallel info + dp_world_size: int + tp_world_size: int + # tp info + tp_shard_dims: Optional[List[int]] = None + tp_num_parts: Optional[List[int]] = None + # zero info + zero_start_dp_rank: Optional[int] = None + zero_offsets: Optional[List[int]] = None + + @property + def used_tp(self) -> bool: + return self.tp_shard_dims is not None and self.tp_num_parts is not None + + @property + def used_zero(self) -> bool: + return self.zero_start_dp_rank is not None and self.zero_offsets is not None + + +@dataclass +class RankRedistMeta: + dp_rank: int + tp_rank: int + pp_rank: int + + +@dataclass +class PipelineRedistMeta: + params: Set[str] + + +@dataclass +class RedistMeta: + rank_meta: Dict[str, Dict[int, RankRedistMeta]] + pipeline_meta: List[PipelineRedistMeta] + param_meta: Dict[str, ParamRedistMeta] diff --git a/colossalai/utils/checkpoint_io/reader.py b/colossalai/utils/checkpoint_io/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..3158c6481263a7a4be4b15425728d2505b9661e2 --- /dev/null +++ b/colossalai/utils/checkpoint_io/reader.py @@ -0,0 +1,131 @@ +import os +from abc import ABC, abstractmethod +from collections import Counter +from typing import Dict, Generator, List, Optional, Tuple + +import torch + +from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME +from .meta import ParamDistMeta +from .utils import is_duplicated_list + + +class CheckpointReader(ABC): + + def __init__(self, base_name: str) -> None: + super().__init__() + self.base_name = base_name + self.meta_list = [] + + @abstractmethod + def read(self, name: str) -> dict: + pass + + @abstractmethod + def load_meta( + self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]: + pass + + @abstractmethod + def load_model(self, rank: int) -> Generator[dict, None, None]: + pass + + @abstractmethod + def load_models(self) -> Generator[Dict[int, dict], None, None]: + pass + + @abstractmethod + def load_optimizer(self, rank: int) -> Generator[dict, None, None]: + pass + + @abstractmethod + def load_optimizers(self) -> Generator[Dict[int, dict], None, None]: + pass + + @abstractmethod + def load_others(self) -> dict: + pass + + +class DiskCheckpointReader(CheckpointReader): + + def __init__(self, base_name: str) -> None: + super().__init__(base_name) + assert os.path.isdir(base_name), f'"{base_name}" is not a directory' + global_meta = self.read(GLOBAL_META_FILE_NAME) + for meta_file_name in global_meta['meta']: + meta = self.read(meta_file_name) + if meta.get('dist_meta', None) is None: + # only global checkpoint can have empty dist_meta + assert len(global_meta['meta']) == 1 + self.meta_list.append(meta) + + def read(self, name: str) -> dict: + return torch.load(os.path.join(self.base_name, name)) + + def load_meta( + self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]: + meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os', + None), meta.get('paired_os', None)) + for meta in self.meta_list] + dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos) + # reduce param_count + param_count = Counter(p for params in params_list for p in params) + # validate param_to_os + assert is_duplicated_list(param_to_os_list) + assert is_duplicated_list(paired_os_list) + return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0] + + def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]: + meta = self.meta_list[rank] + checkpoint_names = meta.get(shard_type, []) + for name in checkpoint_names: + yield self.read(name) + + def load_model(self, rank: int) -> Generator[dict, None, None]: + return self._load_shard('model', rank) + + def load_models(self) -> Generator[Dict[int, dict], None, None]: + indices = [0] * len(self.meta_list) + while True: + shards = {} + for i, meta in enumerate(self.meta_list): + model_checkpoint_names = meta.get('model', []) + if indices[i] < len(model_checkpoint_names): + shards[i] = self.read(model_checkpoint_names[indices[i]]) + indices[i] += 1 + if len(shards) > 0: + yield shards + else: + break + + def load_optimizer(self, rank: int) -> Generator[dict, None, None]: + param_groups = None + for shard in self._load_shard('optimizer', rank): + if param_groups is None: + param_groups = shard['param_groups'] + else: + shard['param_groups'] = param_groups + yield shard + + def load_optimizers(self) -> Generator[Dict[int, dict], None, None]: + indices = [0] * len(self.meta_list) + param_groups = [] + while True: + shards = {} + for i, meta in enumerate(self.meta_list): + optimizer_checkpoint_names = meta.get('optimizer', []) + if indices[i] < len(optimizer_checkpoint_names): + shards[i] = self.read(optimizer_checkpoint_names[indices[i]]) + if indices[i] == 0: + param_groups.append(shards[i]['param_groups']) + else: + shards[i]['param_groups'] = param_groups[i] + indices[i] += 1 + if len(shards) > 0: + yield shards + else: + break + + def load_others(self) -> dict: + return self.read(OTHER_CKPT_FILE_NAME) diff --git a/colossalai/utils/checkpoint_io/utils.py b/colossalai/utils/checkpoint_io/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..135385f5737947400b8f6ae860f51001ca9742fd --- /dev/null +++ b/colossalai/utils/checkpoint_io/utils.py @@ -0,0 +1,223 @@ +import warnings +from copy import deepcopy +from itertools import chain +from typing import Any, Callable, Dict, List, Optional, Tuple + +from torch import Tensor +from torch.nn import Module +from torch.nn.parameter import Parameter +from torch.optim import Optimizer + +from .meta import ParamDistMeta + + +def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any: + if arg is not None: + return fn(arg) + + +def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]: + # ensure all params in optimizer are in model state dict + params_set = set(id(p) for p in model.parameters()) + for group in optimizer.param_groups: + for p in group['params']: + assert id(p) in params_set + param_mappings = {} + start_index = 0 + + def get_group_mapping(group): + nonlocal start_index + param_mappings.update( + {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings}) + start_index += len(group['params']) + + for g in optimizer.param_groups: + get_group_mapping(g) + return {k: param_mappings[id(p)] for k, p in model.named_parameters()} + + +def compute_optimizer_state_size(state: Dict[str, Any]) -> int: + size = 0 + for v in state.values(): + if isinstance(v, Tensor): + size += v.numel() * v.element_size() + return size + + +class ModelCheckpointSharder: + + def __init__(self, max_shard_size: int) -> None: + self.max_shard_size = max_shard_size + self.buffer: Dict[str, Tensor] = {} + self.buffer_size: int = 0 + + def append(self, key: str, tensor: Tensor) -> Optional[dict]: + retval = None + if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size: + retval = self.buffer + self.buffer = {} + self.buffer_size = 0 + self.buffer[key] = tensor + self.buffer_size += tensor.numel() * tensor.element_size() + return retval + + def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]: + shards = [] + for key, tensor in state_dict.items(): + shard = self.append(key, tensor) + run_if_not_none(shards.append, shard) + return shards + + def complete(self) -> Optional[dict]: + return self.buffer if len(self.buffer) > 0 else None + + +class OptimizerCheckpointSharder: + + def __init__(self, max_shard_size: int, param_groups: dict) -> None: + self.max_shard_size = max_shard_size + self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups} + self.buffer_size: int = 0 + self.returned_first: bool = False + + def append(self, key: int, state: dict) -> Optional[dict]: + retval = None + if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size: + retval = self.buffer + self.buffer = {'state': {}} + self.buffer_size = 0 + self.buffer['state'][key] = state + self.buffer_size += compute_optimizer_state_size(state) + return retval + + def extend(self, state_dict: Dict[str, dict]) -> List[dict]: + shards = [] + for key, state in state_dict['state'].items(): + shard = self.append(key, state) + run_if_not_none(shards.append, shard) + return shards + + def complete(self) -> Optional[dict]: + return self.buffer if len(self.buffer['state']) > 0 else None + + +def shard_checkpoint(max_shard_size: int, + model_state_dict: Dict[str, Tensor], + optimizer_state_dict: Optional[dict] = None, + param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]: + has_optimizer: bool = False + if optimizer_state_dict is not None: + assert param_to_os is not None + os_to_param = {v: k for k, v in param_to_os.items()} + for os_key in optimizer_state_dict['state'].keys(): + assert os_key in os_to_param + assert os_to_param[os_key] in model_state_dict + has_optimizer = True + model_sharder = ModelCheckpointSharder(max_shard_size) + model_shards = model_sharder.extend(model_state_dict) + run_if_not_none(model_shards.append, model_sharder.complete()) + if not has_optimizer: + return model_shards, [] + optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups']) + optimizer_shards = optimizer_sharder.extend(optimizer_state_dict) + run_if_not_none(optimizer_shards.append, optimizer_sharder.complete()) + return model_shards, optimizer_shards + + +def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict: + os_to_param = {v: k for k, v in param_to_os.items()} + paired_os = {} + for idx, state in optimizer_state_dict['state'].items(): + paired_os[idx] = {} + p = model_state_dict[os_to_param[idx]] + for k, v in state.items(): + if isinstance(v, Tensor) and v.shape == p.shape: + paired_os[idx][k] = True + else: + paired_os[idx][k] = False + return paired_os + + +def build_checkpoints(max_size: int, + model: Module, + optimizer: Optional[Optimizer] = None, + param_to_os: Optional[Dict[str, int]] = None, + dist_meta: Optional[Dict[str, ParamDistMeta]] = None, + eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]: + save_global = dist_meta is None + model_state_dict = model.state_dict() + optimizer_state_dict = optimizer.state_dict() if optimizer else None + meta = {'dist_meta': dist_meta} + if optimizer: + param_to_os = param_to_os or get_param_to_os(model, optimizer) + paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os) + meta['param_to_os'] = param_to_os + meta['paired_os'] = paired_os + if not save_global and eliminate_replica: + # filter dp replicated params + model_state_dict = { + k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0 + } + if optimizer: + optimizer_state_dict['state'] = { + param_to_os[k]: optimizer_state_dict['state'][param_to_os[k]] + for k in model_state_dict.keys() + if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0 + } + meta['params'] = list(model_state_dict.keys()) + if len(model_state_dict) == 0: + warnings.warn('model state dict is empty, checkpoint is not saved') + return [], [], meta + model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict, + param_to_os) + return model_checkpoints, optimizer_checkpoints, meta + + +def is_duplicated_list(list_: List[Any]) -> bool: + if len(list_) == 0: + return True + elem = list_[0] + for x in list_[1:]: + if x != elem: + return False + return True + + +def copy_optimizer_state(src_state: dict, dest_state: dict) -> None: + for k, v in src_state.items(): + if k in dest_state: + old_v = dest_state[k] + if isinstance(old_v, Tensor): + old_v.copy_(v) + else: + dest_state[k] = v + + +def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None: + assert optimizer.state_dict()['param_groups'] == state_dict['param_groups'] + state_dict = deepcopy(state_dict) + groups = optimizer.param_groups + saved_groups = state_dict['param_groups'] + idx_to_p: Dict[str, Parameter] = { + old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups + )), chain.from_iterable((g['params'] for g in groups))) + } + missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys())) + unexpected_keys = [] + error_msgs = [] + for idx, state in state_dict['state'].items(): + if idx in idx_to_p: + old_state = optimizer.state[idx_to_p[idx]] + copy_optimizer_state(state, old_state) + else: + unexpected_keys.append(idx) + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__, + "\n\t".join(error_msgs))) diff --git a/colossalai/utils/checkpoint_io/writer.py b/colossalai/utils/checkpoint_io/writer.py new file mode 100644 index 0000000000000000000000000000000000000000..4552accde470d9e0b54b9d658fdcfb6a33e4bc78 --- /dev/null +++ b/colossalai/utils/checkpoint_io/writer.py @@ -0,0 +1,98 @@ +from abc import ABC, abstractmethod +from typing import Optional +from .constant import MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME, META_CKPT_FILE_NAME, OTHER_CKPT_FILE_NAME, GLOBAL_META_FILE_NAME +import torch +import os + + +class CheckpointWriter(ABC): + + def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None: + super().__init__() + self.base_name = base_name + self.overwrite = overwrite + self.rank = rank + self.world_size = world_size + self.is_distributed = world_size > 1 + self.is_main_process = rank == 0 + + @abstractmethod + def write(self, name: str, state_dict: dict) -> None: + pass + + @abstractmethod + def save_model(self, model_checkpoint: dict) -> None: + pass + + @abstractmethod + def save_optimizer(self, optimizer_checkpoint: dict) -> None: + pass + + @abstractmethod + def save_meta(self, meta_checkpoint: dict) -> None: + pass + + @abstractmethod + def save_others(self, kwargs: dict) -> None: + pass + + +class DiskCheckpointWriter(CheckpointWriter): + + def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None: + super().__init__(base_name, overwrite, rank, world_size) + if not os.path.exists(base_name): + os.makedirs(base_name) + assert os.path.isdir(base_name), f'"{base_name}" is not a directory' + self.model_checkpoint_names = [] + self.optimizer_checkpoint_names = [] + self.is_meta_saved: bool = False + self._save_global_meta() + + def write(self, name: str, state_dict: dict) -> None: + path = os.path.join(self.base_name, name) + if os.path.exists(path) and not self.overwrite: + raise RuntimeError(f'Save error: Checkpoint "{path}" exists. (overwrite = False)') + torch.save(state_dict, path) + + def _save_global_meta(self) -> None: + if self.is_main_process: + global_meta = {'meta': []} + if self.is_distributed: + for i in range(self.world_size): + global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin')) + else: + global_meta['meta'].append(META_CKPT_FILE_NAME) + self.write(GLOBAL_META_FILE_NAME, global_meta) + + def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str: + checkpoint_name = base_name + if self.is_distributed: + checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin') + if shard_idx is not None: + checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin') + return checkpoint_name + + def save_model(self, model_checkpoint: dict) -> None: + assert not self.is_meta_saved, 'Cannot save model after saving meta' + name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names)) + self.write(name, model_checkpoint) + self.model_checkpoint_names.append(name) + + def save_optimizer(self, optimizer_checkpoint: dict) -> None: + assert not self.is_meta_saved, 'Cannot save optimizer after saving meta' + name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names)) + self.write(name, optimizer_checkpoint) + self.optimizer_checkpoint_names.append(name) + + def save_meta(self, meta_checkpoint: dict) -> None: + if len(self.model_checkpoint_names) > 0: + meta_checkpoint['model'] = self.model_checkpoint_names + if len(self.optimizer_checkpoint_names) > 0: + meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names + self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint) + self.is_meta_saved = True + + def save_others(self, kwargs: dict) -> None: + if self.is_main_process: + self.write(OTHER_CKPT_FILE_NAME, kwargs) diff --git a/colossalai/utils/checkpointing.py b/colossalai/utils/checkpointing.py new file mode 100644 index 0000000000000000000000000000000000000000..d1c6b6370ede4420191eba75e3ab7816f99ed499 --- /dev/null +++ b/colossalai/utils/checkpointing.py @@ -0,0 +1,266 @@ +from collections import OrderedDict +from itertools import chain + +import torch +import torch.distributed as dist +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.constants import IS_TENSOR_PARALLEL +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + +from .common import is_using_pp + +__all__ = ["save_checkpoint", "load_checkpoint"] + + +def broadcast_state_dict(state_dict, parallel_mode): + state_dict = [state_dict.copy() if isinstance(state_dict, dict) else state_dict] + src_rank = gpc.get_ranks_in_group(parallel_mode)[0] + dist.broadcast_object_list(state_dict, src=src_rank, group=gpc.get_cpu_group(parallel_mode)) + return state_dict[0] + + +def partition_tensor_parallel_state_dict(state_dict: OrderedDict, + parallel_mode: ParallelMode, + dims: dict = dict(), + partition_states: dict = dict()): + src_rank = gpc.get_ranks_in_group(parallel_mode)[0] + depth = gpc.get_world_size(parallel_mode) + group = gpc.get_cpu_group(parallel_mode) + is_rank0 = gpc.get_local_rank(parallel_mode) == 0 + partition_info = [None] + if is_rank0: + partition_info_dict = OrderedDict() + for key, param in state_dict.items(): + dim = dims[key] + is_partitioned = partition_states[key] + shape = list(param.shape) + if is_partitioned: + shape[dim] = shape[dim] // depth + partition_info_dict[key] = (is_partitioned, param.dtype, shape, dim) + partition_info[0] = partition_info_dict + dist.broadcast_object_list(partition_info, src_rank, group=group) + partitioned_state = OrderedDict() + for key, (is_partitioned, dtype, shape, dim) in partition_info[0].items(): + if is_partitioned: + output = torch.empty(shape, dtype=dtype) + if is_rank0: + scatter_list = [t.contiguous() for t in state_dict[key].chunk(depth, dim)] + else: + scatter_list = None + dist.scatter(output, scatter_list, src_rank, group=group) + else: + if is_rank0: + output = state_dict[key] + else: + output = torch.empty(shape, dtype=dtype) + dist.broadcast(output, src_rank, group=group) + partitioned_state[key] = output + return partitioned_state + + +def gather_tensor_parallel_state_dict( + state_dict: OrderedDict, + parallel_mode: ParallelMode, + dims: dict = dict(), + partition_states: dict = dict(), + keep_vars: bool = False, +): + dst_rank = gpc.get_ranks_in_group(parallel_mode)[0] + depth = gpc.get_world_size(parallel_mode) + + for key in list(state_dict.keys()): + param = state_dict.pop(key) + param = param if keep_vars else param.detach() + dim = dims.get(key, 0) + do_partition = partition_states.get(key, True) + if do_partition: + temp = param.transpose(0, dim).contiguous() + gather_list = None + if gpc.get_local_rank(parallel_mode) == 0: + shape = list(param.shape) + shape[0], shape[dim] = shape[dim], shape[0] + shape[0] *= depth + param = torch.empty(shape, dtype=param.dtype, device=param.device) + gather_list = list(torch.chunk(param, depth, dim=0)) + dist.gather(temp, gather_list, dst=dst_rank, group=gpc.get_cpu_group(parallel_mode)) + param = torch.transpose(param, 0, dim) + # update params in state_dict only on local rank 0 + if gpc.get_local_rank(parallel_mode) == 0: + state_dict[key] = param + + return state_dict + + +def _send_state_dict(state_dict, dst, parallel_mode): + state_tensor, state_size = dist.distributed_c10d._object_to_tensor(state_dict) + dist.send(state_size, dst, group=gpc.get_cpu_group(parallel_mode)) + dist.send(state_tensor, dst, group=gpc.get_cpu_group(parallel_mode)) + + +def _recv_state_dict(src, parallel_mode): + state_size = torch.tensor([0], dtype=torch.long) + dist.recv(state_size, src, group=gpc.get_cpu_group(parallel_mode)) + state_tensor = torch.empty(state_size.item(), dtype=torch.uint8) + dist.recv(state_tensor, src, group=gpc.get_cpu_group(parallel_mode)) + state_dict = dist.distributed_c10d._tensor_to_object(state_tensor, state_size) + return state_dict + + +def partition_pipeline_parallel_state_dict(model, state_dict): + pipeline_state = OrderedDict() + + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # receive all states from prev stage + if not gpc.is_first_rank(ParallelMode.PIPELINE): + state_dict = _recv_state_dict(gpc.get_prev_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE) + # move states to output + for name, _ in model.named_parameters(recurse=True): + if name in state_dict: + pipeline_state[name] = state_dict.pop(name) + for name, _ in model.named_buffers(recurse=True): + if name in state_dict: + pipeline_state[name] = state_dict.pop(name) + for name, _ in model.named_modules(): + extra_state_key = name + "." + _EXTRA_STATE_KEY_SUFFIX + if extra_state_key in state_dict: + pipeline_state[extra_state_key] = state_dict.pop(extra_state_key) + # send rest states to next stage + if not gpc.is_last_rank(ParallelMode.PIPELINE): + _send_state_dict(state_dict, gpc.get_next_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE) + + return pipeline_state + + +def gather_pipeline_parallel_state_dict(state_dict): + gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))] + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None) + dist.gather_object( + state_dict, + gathered_states, + dst=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[0], + group=gpc.get_cpu_group(ParallelMode.PIPELINE), + ) + + state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states)) + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict()) + + return state_dict + + +def save_checkpoint(file, + epoch: int, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + **kwargs): + """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer, + lr_scheduler etc. into a checkpoint dictionary. + + Args: + file: a file-like object (has to implement write and flush) or a string or os.PathLike object containing a + file name. + epoch (int): Epoch number (indicates how many epochs have you trained this model). + model (:class:`torch.nn.Module`): Model to be saved. + optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved. + lr_scheduler (Union[:class:`torch.optim.lr_scheduler`, :class:`colossalai.nn.lr_scheduler`], optional): + lr_scheduler to be saved, defaults to None. + pickle_module: module used for pickling metadata and objects + pickle_protocol: can be specified to override the default protocol + """ + # ckpt container + checkpoint = {"epoch": epoch} + + model_state = model.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + model_state = gather_pipeline_parallel_state_dict(model_state) + + if gpc.get_global_rank() == 0: + checkpoint["model"] = model_state + + # if optimizer is not None: + # checkpoint['optimizer'] = optimizer.state_dict() + + # if lr_scheduler is not None: + # checkpoint['lr_scheduler'] = lr_scheduler.state_dict() + + torch.save(checkpoint, file, **kwargs) + + +def broadcast_model(model: torch.nn.Module): + src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0] + for p in model.parameters(): + if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0: + group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group( + ParallelMode.TENSOR) + dist.broadcast(p, src_rank, group=group) + + +def load_checkpoint( + file, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + strict: bool = True, +): + """Loads training states from a checkpoint file. + + Args: + file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike + object containing a file name. + model (:class:`torch.nn.Module`): Model to load saved weights and buffers. + optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to recuperate. + lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional): + lr_scheduler to recuperate, defaults to None. + strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict` + of the checkpoint match the names of parameters and buffers in model, defaults to True. + + Returns: + int: The saved epoch number. + + Raises: + RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated + """ + state_dict = (torch.load(file, map_location=torch.device("cpu")) + if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None) + + # model states + model_state = state_dict.pop("model") if state_dict is not None else dict() + # pipeline + if is_using_pp(): + model_state = partition_pipeline_parallel_state_dict(model, model_state) + try: + model.load_state_dict(model_state, strict=strict) + broadcast_model(model) + except RuntimeError as e: + error_msgs = str(e) + if error_msgs.startswith("Error(s) in loading state_dict for "): + error_msgs = error_msgs.split("\n\t")[1:] + dst_rank = gpc.get_ranks_in_group(ParallelMode.MODEL)[0] + all_error_msgs = [None for _ in range(gpc.get_world_size(ParallelMode.MODEL))] + dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL)) + if gpc.get_global_rank() == 0: + all_error_msgs = list(chain.from_iterable(all_error_msgs)) + raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format( + model.__class__.__name__, "\n\t".join(all_error_msgs))) + else: + raise e + + # broadcast the rest states + state_dict = broadcast_state_dict(state_dict, ParallelMode.MODEL) + + # # optimizer states + # if optimizer is not None and 'optimizer' in state_dict: + # optimizer.load_state_dict(state_dict['optimizer']) + + # # lr scheduler states + # if lr_scheduler is not None and 'lr_scheduler' in state_dict: + # lr_scheduler.load_state_dict(state_dict['lr_scheduler']) + + # last epoch + last_epoch = state_dict.pop("epoch", -1) + + return last_epoch diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..d8cd709b3d4a77bc9a61459a32865e993a6fe6e7 --- /dev/null +++ b/colossalai/utils/common.py @@ -0,0 +1,477 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import functools +import os +import random +import socket +from pathlib import Path +from typing import Callable, Dict, List, Optional, Union + +import torch +from torch._six import inf +from torch.nn.parameter import Parameter + +try: + import colossalai._C.fused_optim +except: + pass + +from collections import defaultdict +from contextlib import contextmanager + +import torch.distributed as dist + +from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.tensor import ColoParameter, ProcessGroup + +from .multi_tensor_apply import multi_tensor_applier + + +def print_rank_0(msg: str, logger=None): + """Print messages and save logs(optional). This is executed only if you are the rank-0 gpu. + + Args: + msg (str): A string message to output. + logger (:class:`colossalai.logging.DistributedLogger`, optional): + The logger to record the message, defaults to None. + """ + if gpc.get_global_rank() == 0: + if logger is None: + print(msg, flush=True) + else: + logger.info(msg) + + +def ensure_path_exists(filename: str): + # ensure the path exists + dirpath = os.path.dirname(filename) + if not os.path.exists(dirpath): + Path(dirpath).mkdir(parents=True, exist_ok=True) + + +def free_port(): + while True: + try: + sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + port = random.randint(20000, 65000) + sock.bind(('localhost', port)) + sock.close() + return port + except Exception: + continue + + +def sync_model_param(model, parallel_mode): + r"""Make sure data parameters are consistent during Data Parallel Mode. + + Args: + model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. + parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel mode to be checked. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: + for param in model.parameters(): + ranks = gpc.get_ranks_in_group(parallel_mode) + dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode)) + + +def is_dp_rank_0(): + return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA) + + +def is_tp_rank_0(): + return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR) + + +def is_no_pp_or_last_stage(): + return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE) + + +def is_using_ddp(): + return gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1 + + +def is_using_pp(): + return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1 + + +def is_using_sequence(): + return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1 + + +@contextmanager +def conditional_context(context_manager, enable=True): + if enable: + with context_manager: + yield + else: + yield + + +class model_branch_context(object): + + def __enter__(self): + self.env_status = env.save() + + def __exit__(self, *exc_info): + env.load(**self.env_status) + + +def is_model_parallel_parameter(p): + return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL) + + +def _calc_l2_norm(grads): + norm = 0.0 + if len(grads) > 0: + dummy_overflow_buf = torch.cuda.IntTensor([0]) + norm, _ = multi_tensor_applier( + colossalai._C.fused_optim.multi_tensor_l2norm, + dummy_overflow_buf, + [grads], + False # no per-parameter norm + ) + return norm + + +def _calc_lp(grads, norm_type): + norm = 0.0 + for grad in grads: + grad_norm = torch.norm(grad, norm_type) + norm += grad_norm**norm_type + return norm + + +def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: + if torch.is_tensor(norm) and norm.device.type != 'cuda': + norm = norm.to(torch.cuda.current_device()) + return norm + + +def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor: + if isinstance(norm, float): + norm = torch.Tensor([norm]) + if move_to_cuda: + norm = norm.to(torch.cuda.current_device()) + return norm + + +# ======== Gradient Clipping ========= + + +def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float: + if len(params) == 0: + return 0.0 + grads = [p.grad for p in params] + use_cuda_kernel = grads[0].device.type == 'cuda' + if norm_type == inf: + local_lp = max([g.abs().max() for g in grads]) + elif norm_type == 2.0 and use_cuda_kernel: + local_lp = _calc_l2_norm(grads)**norm_type + else: + local_lp = _calc_lp(grads, norm_type) + if isinstance(local_lp, torch.Tensor): + return local_lp.item() + return local_lp + + +def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float: + if len(params) == 0: + return 0.0 + buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list) + for p in params: + if p.is_replicate(): + buckets[None].append(p) + else: + buckets[p.get_process_group().tp_process_group()].append(p) + total_lp = 0.0 + for group, bucket in buckets.items(): + local_lp = _compute_local_lp(bucket, norm_type) + if group is not None: + local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device()) + if norm_type == inf: + dist.all_reduce(local_lp_tensor, op=dist.ReduceOp.MAX, group=group) + else: + dist.all_reduce(local_lp_tensor, group=group) + local_lp = local_lp_tensor.item() + if norm_type == inf: + total_lp = max(total_lp, local_lp) + else: + total_lp += local_lp + return total_lp + + +def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float: + if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: + total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device()) + if norm_type == inf: + dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE)) + else: + dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE)) + total_lp = total_lp_tensor.item() + return total_lp + + +def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grad_dtype = None + cpu_grad_params: List[ColoParameter] = [] + cuda_grad_params: List[ColoParameter] = [] + for p in parameters: + if p.grad is None: + continue + assert isinstance(p, ColoParameter) + if grad_dtype is None: + grad_dtype = p.grad.dtype + assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}' + if p.grad.device.type == 'cuda': + cuda_grad_params.append(p) + else: + cpu_grad_params.append(p) + norm_type = float(norm_type) + cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type) + cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type) + if norm_type == inf: + total_lp = max(cpu_lp, cuda_lp) + else: + total_lp = cpu_lp + cuda_lp + return _compute_pp_grad_lp(total_lp, norm_type) + + +def compute_grad_norm(parameters, norm_type: float = 2.0) -> float: + norm_type = float(norm_type) + total_norm = _compute_grad_lp(parameters, norm_type) + if norm_type != inf: + total_norm = total_norm**(1 / norm_type) + return total_norm + + +def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None: + clip_coef = max_norm / (total_norm + 1e-6) + if clip_coef < 1.0: + cuda_grads: List[torch.Tensor] = [] + cpu_grads: List[torch.Tensor] = [] + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + for p in parameters: + if p.grad is None: + continue + if p.grad.device.type == 'cuda': + cuda_grads.append(p.grad.detach()) + else: + cpu_grads.append(p.grad.detach()) + if len(cuda_grads) > 0: + dummy_overflow_buf = torch.cuda.IntTensor([0]) + multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, dummy_overflow_buf, + [cuda_grads, cuda_grads], clip_coef) + for g in cpu_grads: + g.mul_(clip_coef) + + +def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float: + total_norm = compute_grad_norm(parameters, norm_type) + _clip_grad_norm(parameters, max_norm, total_norm) + return total_norm + + +def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): + """Clips gradient norm of an iterable of parameters whose gradients are in fp32. + + This is adapted from :func:`torch.nn.utils.clip_grad.clip_grad_norm_` and + added functionality to handle model parallel parameters. + + Note: + the gradients are modified in place. + + Args: + parameters (Iterable[:class:`torch.tensor`] or :class:`torch.tensor`): + An iterable of Tensors or a single Tensor that will have gradients normalized. + max_norm (Union[float, int]): Max norm of the gradients. + norm_type (Union[float, int, 'inf']): Type of the used p-norm. Can be ``'inf'`` for infinity norm. + + Returns: + float: Total norm of the parameters. + """ + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + params: List[Parameter] = [] + has_zero_shared_param: bool = False + for param in parameters: + if param.grad is not None: + # Make sure the grads are in fp32 + assert param.grad.dtype == torch.float, \ + f'expected gradient to be dtype torch.float, but got {param.grad.type()}' + if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded: + has_zero_shared_param = True + params.append(param) + + if len(params) == 0: + enable_cuda_kernels = False + else: + enable_cuda_kernels = params[0].grad.device.type == 'cuda' + # Norm parameters. + max_norm = float(max_norm) + norm_type = float(norm_type) + + # Parameters can be on CPU or CUDA + # If parameters are on CPU, disable CUDA kernerls + + # Calculate norm. + if norm_type == inf: + total_norm = max(p.grad.data.abs().max() for p in params) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + # Take max across all model-parallel GPUs. + if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: + dist.all_reduce(total_norm_cuda, + op=dist.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.MODEL), + async_op=False) + if has_zero_shared_param: + dist.all_reduce(total_norm_cuda, + op=dist.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.DATA), + async_op=False) + total_norm = total_norm_cuda[0].item() + else: + tensor_parallel_grads = [] + no_tensor_parallel_grads = [] + zero_sharded_grads = [] + for p in params: + if is_model_parallel_parameter(p): + reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type) + tensor_parallel_grads.append(p.grad.data / reductor) + elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded: + zero_sharded_grads.append(p.grad.data) + else: + no_tensor_parallel_grads.append(p.grad.data) + + if norm_type == 2.0 and enable_cuda_kernels: + tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type + no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type + zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type + else: + tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type) + no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type) + zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type) + # If norm is type of float, then we convert them into torch.Tensor. + tensor_parallel_norm = _get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels) + no_tensor_parallel_norm = _get_tensor_norm(no_tensor_parallel_norm, enable_cuda_kernels) + zero_sharded_norm = _get_tensor_norm(zero_sharded_norm, enable_cuda_kernels) + # If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors + if not enable_cuda_kernels: + tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm) + no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm) + zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm) + + # Sum across all model-parallel GPUs. + if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0: + dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) + # Sum across all zero sharded GPUs + if len(zero_sharded_grads) > 0: + dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA)) + no_tensor_parallel_norm += zero_sharded_norm + total_norm = tensor_parallel_norm + no_tensor_parallel_norm + if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE)) + total_norm = total_norm**(1.0 / norm_type) + if torch.is_tensor(total_norm): + total_norm = total_norm.item() + + # Scale. + clip_coeff = max_norm / (total_norm + 1.0e-6) + if clip_coeff < 1.0: + if enable_cuda_kernels: + grads = [p.grad.detach() for p in params] + dummy_overflow_buf = torch.cuda.IntTensor([0]) + multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads], + clip_coeff) + else: + for p in params: + p.grad.detach().mul_(clip_coeff) + return total_norm + + +def count_zeros_fp32(parameters): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + total_num_zeros = 0.0 + for param in parameters: + grad_not_none = param.grad is not None + is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) + if grad_not_none and is_not_tp_duplicate: + grad = param.grad.detach() + num_zeros = grad.numel() - torch.count_nonzero(grad) + total_num_zeros = num_zeros + total_num_zeros + + total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda() + + # Sum across all model-parallel GPUs. + ops = [] + ops.append( + dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True)) + if gpc.is_initialized(ParallelMode.PIPELINE): + ops.append( + dist.all_reduce(total_num_zeros, + op=dist.ReduceOp.SUM, + group=gpc.get_group(ParallelMode.PIPELINE), + async_op=True)) + + for req in ops: + req.wait() + total_num_zeros = total_num_zeros.item() + + return total_num_zeros + + +def copy_tensor_parallel_attributes(src_tensor, dst_tensor): + for attr in TENSOR_PARALLEL_ATTRIBUTES: + if hasattr(src_tensor, attr): + val = getattr(src_tensor, attr) + setattr(dst_tensor, attr, val) + + +def param_is_not_tensor_parallel_duplicate(param): + return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank( + ParallelMode.TENSOR) == 0) + + +@contextmanager +def switch_virtual_pipeline_parallel_rank(rank): + prev_rank = gpc.virtual_pipeline_parallel_rank + try: + gpc.set_virtual_pipeline_parallel_rank(rank) + yield + finally: + gpc.set_virtual_pipeline_parallel_rank(prev_rank) + + +def disposable(func: Callable) -> Callable: + executed = False + + @functools.wraps(func) + def wrapper(*args, **kwargs): + nonlocal executed + if not executed: + executed = True + return func(*args, **kwargs) + + return wrapper diff --git a/colossalai/utils/cuda.py b/colossalai/utils/cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..60f3ccb60883e7af56da5f41e1b18c7e20cc098f --- /dev/null +++ b/colossalai/utils/cuda.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + + +def set_to_cuda(models): + """Send model to gpu. + + :param models: nn.module or a list of module + """ + if isinstance(models, list) and len(models) > 1: + ret = [] + for model in models: + ret.append(model.to(get_current_device())) + return ret + elif isinstance(models, list): + return models[0].to(get_current_device()) + else: + return models.to(get_current_device()) + + +def get_current_device() -> torch.device: + """ + Returns currently selected device (gpu/cpu). + If cuda available, return gpu, otherwise return cpu. + """ + if torch.cuda.is_available(): + return torch.device(f'cuda:{torch.cuda.current_device()}') + else: + return torch.device('cpu') + + +def synchronize(): + """Similar to cuda.synchronize(). + Waits for all kernels in all streams on a CUDA device to complete. + """ + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def empty_cache(): + """Similar to cuda.empty_cache() + Releases all unoccupied cached memory currently held by the caching allocator. + """ + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/colossalai/utils/data_sampler/__init__.py b/colossalai/utils/data_sampler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..12798a94c2d063bb120f805967e748c5a1059a3a --- /dev/null +++ b/colossalai/utils/data_sampler/__init__.py @@ -0,0 +1,4 @@ +from .base_sampler import BaseSampler +from .data_parallel_sampler import DataParallelSampler, get_dataloader + +__all__ = ['BaseSampler', 'DataParallelSampler', 'get_dataloader'] diff --git a/colossalai/utils/data_sampler/base_sampler.py b/colossalai/utils/data_sampler/base_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..89f3bca5b1b51925ef7b32e4a08f1df301776fcb --- /dev/null +++ b/colossalai/utils/data_sampler/base_sampler.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod + + +class BaseSampler(ABC): + + def __init__(self, dataset, batch_size): + self.dataset = dataset + self.batch_size = batch_size + + @abstractmethod + def __len__(self): + pass + + @abstractmethod + def __iter__(self): + pass diff --git a/colossalai/utils/data_sampler/data_parallel_sampler.py b/colossalai/utils/data_sampler/data_parallel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..945dc54b397a3869232827838d4596e95e188059 --- /dev/null +++ b/colossalai/utils/data_sampler/data_parallel_sampler.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +# adpated from torch.utils.data.DistributedSampler + +import math +import random +import numpy as np +from typing import TypeVar, Iterator + +import torch +from torch.utils.data import Sampler, Dataset, DataLoader + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.registry import DATA_SAMPLERS + +T_co = TypeVar('T_co', covariant=True) + + +@DATA_SAMPLERS.register_module +class DataParallelSampler(Sampler): + """A data sampler for distributed data parallelism. + + Args: + dataset (:class:`torch.utils.data.Dataset`): The Dataset for sampling. + shuffle (bool, optional): Whether to shuffle data, defaults to False. + seed (int, optional): The random seed used for sampling, defaults to 0. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + """ + + def __init__(self, + dataset: Dataset, + shuffle: bool = False, + seed: int = 0, + drop_last: bool = False) -> None: + self.dataset = dataset + self.num_replicas = gpc.get_world_size(ParallelMode.DATA) + self.rank = gpc.get_local_rank(ParallelMode.DATA) + self.epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + # type: ignore[arg-type] + if self.drop_last and len(self.dataset) % self.num_replicas != 0: + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + # `type:ignore` is required because Dataset cannot provide a default __len__ + # see NOTE in pytorch/torch/utils/data/sampler.py + (len(self.dataset) - self.num_replicas) / \ + self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil( + len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator[T_co]: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + # type: ignore[arg-type] + indices = torch.randperm(len(self.dataset), generator=g).tolist() + + # update for next epoch so that there is no need to call + # set_epoch manually + self.epoch += 1 + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / + len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[:self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + r"""Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch + + +def get_dataloader(dataset, + shuffle=False, + seed=1024, + add_sampler=True, + drop_last=False, + pin_memory=False, + num_workers=0, + **kwargs): + r"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not) + + Note: + When pipeline parallel is enabled, shuffle cannot be True as it will result in mismatch between input data + on the 1st stage and label on the last stage. + + Args: + dataset (:class:`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + + if add_sampler and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1: + sampler = DataParallelSampler(dataset, shuffle=shuffle) + else: + sampler = None + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + if sampler is None: + return DataLoader(dataset, + worker_init_fn=seed_worker, + shuffle=shuffle, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs) + else: + return DataLoader(dataset, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs) diff --git a/colossalai/utils/memory.py b/colossalai/utils/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..434e90edd3b98fb7f69c502d2d5ebf21e127d0bb --- /dev/null +++ b/colossalai/utils/memory.py @@ -0,0 +1,178 @@ +import torch +import gc +import psutil +from collections import namedtuple + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.utils import get_current_device +from colossalai.core import global_context as gpc +from colossalai.context.parallel_mode import ParallelMode +from colossalai.logging import get_dist_logger +from packaging import version + +_GLOBAL_CUDA_MEM_FRACTION = 1.0 +_GLOBAL_CPU_MEM_CAPACITY = -1 + + +def _bytes_to_MB(val, decimal=2): + """A byte-to-Megabyte converter, default using binary notation. + + :param val: X bytes to convert + :return: X' MB + """ + return round(val / (1024 * 1024), decimal) + + +# copy from PatrickStar +def _get_cpu_memory_info(): + ps_mem_info = namedtuple("ps_mem_info", ["total", "free", "cached", "buffers", "used"]) + try: + # psutil reads the memory info from /proc/memory_info, + # which results in returning the host memory instead of + # that of container. + # Here we try to read the container memory with method in: + # https://stackoverflow.com/a/46213331/5163915 + mems = {} + with open("/sys/fs/cgroup/memory/memory.meminfo", "rb") as f: + for line in f: + fields = line.split() + mems[fields[0]] = int(fields[1]) * 1024 + total = mems[b"MemTotal:"] + free = mems[b"MemFree:"] + cached = mems[b"Cached:"] + buffers = mems[b"Buffers:"] + used = total - free - cached - buffers + if used < 0: + used = total - free + mem_info = ps_mem_info(total=total, free=free, cached=cached, buffers=buffers, used=used) + except FileNotFoundError: + mems = psutil.virtual_memory() + mem_info = ps_mem_info( + total=mems.total, + free=mems.free, + cached=mems.cached, + buffers=mems.buffers, + used=mems.used, + ) + return mem_info + + +def report_memory_usage(message, logger=None, report_cpu=False): + """Calculate and print RAM usage (in GB) + + Args: + message (str): A prefix message to add in the log. + logger (:class:`colossalai.logging.DistributedLogger`): The logger used to record memory information. + report_cpu (bool, optional): Whether to report CPU memory. + + Raises: + EnvironmentError: Raise error if no distributed environment has been initialized. + """ + if not gpc.is_initialized(ParallelMode.GLOBAL): + raise EnvironmentError("No distributed environment is initialized") + + gpu_allocated = _bytes_to_MB(torch.cuda.memory_allocated()) + gpu_max_allocated = _bytes_to_MB(torch.cuda.max_memory_allocated()) + gpu_cached = _bytes_to_MB(torch.cuda.memory_reserved()) + gpu_max_cached = _bytes_to_MB(torch.cuda.max_memory_reserved()) + + full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \ + + f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB" + + if report_cpu: + # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports + gc.collect() + vm_stats = psutil.virtual_memory() + vm_used = _bytes_to_MB(vm_stats.total - vm_stats.available) + full_log += f", CPU Virtual Memory: used = {vm_used} MB, percent = {vm_stats.percent}%" + + if logger is None: + logger = get_dist_logger() + logger.info(full_log) + + # get the peak memory to report correct data, so reset the counter for the next call + if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ + torch.cuda.reset_peak_memory_stats() + + +def colo_device_memory_capacity(device: torch.device) -> int: + """ + Get the capacity of the memory of the device + + Args: + device (torch.device): a device + + Returns: + int: size in byte + """ + assert isinstance(device, torch.device) + if device.type == 'cpu': + # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory. + return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node + if device.type == 'cuda': + return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION + + +def colo_device_memory_used(device: torch.device) -> int: + """ + Get the device memory on device belonging to the current process. + + Args: + device (torch.device): a device + + Returns: + int: memory size in bytes + """ + if device.type == 'cpu': + mem_info = _get_cpu_memory_info() + # In the context of 1-CPU-N-GPU, the memory usage of the current process is 1/N CPU memory used. + # Each process consumes the same amount of memory. + ret = mem_info.used / gpc.num_processes_on_current_node + return ret + elif device.type == 'cuda': + ret: int = torch.cuda.memory_allocated(device) + # get the peak memory to report correct data, so reset the counter for the next call + if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ + torch.cuda.reset_peak_memory_stats(device) + return ret + + +def colo_set_process_memory_fraction(ratio: float) -> None: + """colo_set_process_memory_fraction + + set how much cuda memory used on the gpu belonging to the current process. + + Args: + ratio (float): a ratio between 0. ~ 1. + """ + if version.parse(torch.__version__) < version.parse('1.8'): + logger = get_dist_logger('colo_set_process_memory_fraction') + logger.warning('colo_set_process_memory_fraction failed because torch version is less than 1.8') + return + global _GLOBAL_CUDA_MEM_FRACTION + _GLOBAL_CUDA_MEM_FRACTION = ratio + torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device()) + + +def colo_set_cpu_memory_capacity(size: int) -> None: + global _GLOBAL_CPU_MEM_CAPACITY + mem_info = _get_cpu_memory_info() + total_size = mem_info.total + if size <= total_size: + _GLOBAL_CPU_MEM_CAPACITY = size + else: + _GLOBAL_CPU_MEM_CAPACITY = total_size + + +def colo_get_cpu_memory_capacity() -> int: + """ + Get the cpu memory capacity. We may not use all of it. + Returns: + int: _description_ + """ + global _GLOBAL_CPU_MEM_CAPACITY + if _GLOBAL_CPU_MEM_CAPACITY == -1: + mem_info = _get_cpu_memory_info() + return mem_info.total + else: + return _GLOBAL_CPU_MEM_CAPACITY diff --git a/colossalai/utils/model/__init__.py b/colossalai/utils/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py new file mode 100644 index 0000000000000000000000000000000000000000..93c91e0995ea106c33d05878dfdc9bb32ef64a31 --- /dev/null +++ b/colossalai/utils/model/colo_init_context.py @@ -0,0 +1,196 @@ +from typing import Any, Dict, Iterator, Optional, Tuple, Union + +import torch +from torch import nn + +from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module +from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup + +from .utils import InsertPostInitMethodToModuleSubClasses + +# find named_params includes replica + + +def _named_params_with_replica( + module: nn.Module, + prefix: str = '', + recurse: bool = True, +) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]: + modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)] + + for mod_prefix, mod in modules: + for name, val in mod._parameters.items(): + if val is None: + continue + name = mod_prefix + ('.' if mod_prefix else '') + name + yield name, val + + +def _convert_to_coloparam(param: torch.nn.Parameter, + device: torch.device, + dtype=torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec: Optional[Any] = None) -> ColoParameter: + + if isinstance(param, ColoParameter): + return param + # detaching tensor is necessary for optimizers. + requires_grad = param.requires_grad + # param is the global tensor. + + if param.device.type == "meta": + colo_param = ColoParameter(param, requires_grad=requires_grad) + else: + colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad) + + + # if default_shard_plan exists, shard the param during initialization. + # This can reduce the model size after initialization. + # NOTE() embedding usually can not be correctly sharded. So I use except to handle + # the param that can not be sharded by the default plan + if default_pg is not None: + colo_param.set_process_group(default_pg) + + if default_dist_spec is not None: + try: + colo_param.set_dist_spec(default_dist_spec) + except: + pass + return colo_param + + +def ColoModulize(module): + """ + Replacing the parameters() and named_parameters() with our customized ones + """ + + module._colo_visited = True + + +class ColoInitContext(InsertPostInitMethodToModuleSubClasses): + + def __init__(self, + device: torch.device = torch.device('cpu'), + dtype: torch.dtype = torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec=None): + """ + Args: + device (torch.device): the device where parameters initialized are resident. Defaults to torch.device('cpu'). + dtype (torch.dtype): the dtype of parameters initialized. Defults to torch.float. + default_pg (ProcessGroup): the default process group for all initialized parameters. + default_dist_spec: the default distributed specifications. + """ + super().__init__() + self._device = device + self._dtype = dtype + + self._register_colo_modules() + self._default_pg = default_pg + self._default_dist_spec = default_dist_spec + + def _register_colo_modules(self): + register_colo_module(torch.nn.Linear, ColoLinear()) + register_colo_module(torch.nn.Embedding, ColoEmbedding()) + + def _pre_context_exec(self): + pass + + def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): + """ + The function to call at the end of the constructor of each module. + FIXME(fjr) The module may be passed to this function multiple times? + """ + name_list = [] + for name, param in _named_params_with_replica(module): + if isinstance(param, ColoTensor): + continue + + split = name.rfind('.') + if split >= 0: # param in submodule + module_name = name[:split] + param_name = name[split + 1:] + else: + module_name = '' # param in current module + param_name = name + name_list.append((module_name, param_name)) + + replaced_tensors = dict( + ) # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference + for module_name, param_name in name_list: + submodule = module.get_submodule(module_name) + param = submodule.get_parameter(param_name) + if param in replaced_tensors: + colo_param = replaced_tensors[param] + else: + colo_param = _convert_to_coloparam(param, self._device, self._dtype, self._default_pg, + self._default_dist_spec) + replaced_tensors[param] = colo_param + delattr(submodule, param_name) + setattr(submodule, param_name, colo_param) + colo_param.shared_param_modules.append(submodule) + + meta_param_flag = 0 + meta_buffer_flag = 0 + for param in module.parameters(): + if param.device.type=="meta": + meta_param_flag = 1 + if meta_param_flag == 1 and param.device.type!="meta": + raise ValueError("Meta parameters and valued parameters can not be in the same model") + + for buffer in module.buffers(): + if buffer.device.type=="meta": + meta_buffer_flag = 1 + if meta_buffer_flag == 1 and buffer.device.type!="meta": + raise ValueError("Meta buffers and valued buffers can not be in the same model") + + if meta_param_flag==1 and meta_buffer_flag==1: + pass + elif meta_buffer_flag==0 and meta_param_flag==1: + for name, buf in module.named_buffers(): + module._buffers[name] = module._buffers[name].to(device=self._device) + elif meta_param_flag==0 and meta_buffer_flag==1: + for name, param in module.named_parameters(): + module._parameters[name] = module._parameters[name].to(device=self._device) + else: + module.to(self._device) + + +def post_process_colo_init_ctx(model: torch.nn.Module, + device: torch.device = torch.device('cpu'), + dtype: torch.dtype = torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec=None): + """post_process_colo_init_ctx + + This function is called after `ColoInitContext`. + + Args: + model (torch.nn.module): the model + device (torch.device, optional): device type of the model params. Defaults to torch.device('cpu'). + dtype (torch.dtype, optional): dtype of the model params. Defaults to torch.float. + default_pg (Optional[ProcessGroup], optional): default process group. Defaults to None. Inidicates a DP-only process group. + default_dist_spec (Any, optional): default dist spec of params. Defaults to None. + + Raises: + RuntimeError: raise error if + """ + + torch_params = [] + for n, p in model.named_parameters(): + if not isinstance(p, ColoParameter): + # print(f"{n} is not a ColoParameter. We are going to converting it to ColoParameter") + torch_params.append((n, p)) + + for (n, param) in torch_params: + name_list = n.split('.') + module = model + for i in range(len(name_list) - 1): + module = module._modules[name_list[i]] + delattr(module, name_list[-1]) + setattr(module, name_list[-1], _convert_to_coloparam(param, device, dtype, default_pg, default_dist_spec)) + + del torch_params + for n, p in model.named_parameters(): + if not isinstance(p, ColoTensor): + raise RuntimeError diff --git a/colossalai/utils/model/lazy_init_context.py b/colossalai/utils/model/lazy_init_context.py new file mode 100644 index 0000000000000000000000000000000000000000..cf05f966089d16884166469cff299e6991192097 --- /dev/null +++ b/colossalai/utils/model/lazy_init_context.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python +# coding: utf-8 + +import inspect +import types +from typing import Callable, List + +import torch +import torch.nn as nn + +from colossalai.tensor import ColoParameter, ColoTensor +from colossalai.utils.model.utils import substitute_init_recursively + + +class LazyInitContext(): + """ + A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor + initialization functions for lazy initialization + + Note: + This API is only experimental and subject to future changes. + + Usage: + with LazyInitContext() as ctx: + model = nn.Linear(10, 10) + model.weight.zero_() + + # make sure the weight is a meta tensor + assert model.weight.is_meta + + # initialize weights + ctx.lazy_init_parameters(model) + + # make sure the weight is not a meta tensor + # and initialized correctly + assert not model.weight.is_meta and torch.all(model.weight == 0) + + Args: + to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This + argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet. + extra_torch_tensor_func (List[str]): extra torch tensor functions related + to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default. + """ + + tensor_set_value_func = ['zero_', 'fill_'] + + def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None): + # TODO: hijack the torch constructor functions as well + self._to_meta = to_meta + self._intercepted_nn_init_func_cache = {} + self._nn_init_methods = self._get_nn_init_methods() + self._torch_mod_cls = torch.nn.modules.module.Module + + if extra_torch_tensor_func: + # use tuple to remove duplicates + self._torch_tensor_funcs = tuple(self.tensor_set_value_func + extra_torch_tensor_func) + else: + self._torch_tensor_funcs = self.tensor_set_value_func + + @property + def to_meta(self): + return self._to_meta + + def _cache_init_func(self, func): + """ + This method wraps the ``torch.nn.init`` method and torch tensor value-setting functions + so that the function call is cached instead of being executed. + """ + + def wrapped_init_func(tensor, *args, **kwargs): + if tensor not in self._intercepted_nn_init_func_cache: + self._intercepted_nn_init_func_cache[tensor] = [] + self._intercepted_nn_init_func_cache[tensor].append((func, args, kwargs)) + + return wrapped_init_func + + def _get_nn_init_methods(self): + """ + This method looks for all available functions in the ``torch.nn.init`` + module. + """ + nn_init_method_names = dir(torch.nn.init) + nn_init_methods = [] + + # look for all methods in ``torch.nn.init`` module + for name in nn_init_method_names: + nn_init_methods.append((name, getattr(torch.nn.init, name))) + + def _is_init_method(item): + name, func = item + + if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')): + return False + else: + return True + + # remove methods which are not init functions + nn_init_methods = list(filter(_is_init_method, nn_init_methods)) + return nn_init_methods + + def _wrap_module_init(self, func): + """ + This method wraps the calls to the `__init__` of ``torch.nn.Module`` and replaces + the argument device with value 'meta' so that all modules are created as meta tensors. + """ + has_device = 'device' in inspect.signature(func).parameters + + def layer_lazy_init(module, *args, **kwargs): + # if this module contains device argument + # we set it to meta to initialize as meta backend + if has_device: + kwargs['device'] = 'meta' + func(module, *args, **kwargs) + + # if device is not found, we intialize it and convert to meta + if not has_device: + module.to('meta') + + return layer_lazy_init + + def _get_tmp_origin_func_ref(self, name): + """ + Generate a function name for consistency during caching and retrieving. + """ + return f'_orig_{name}' + + def _patch_nn_init_funcs(self): + # patch nn.init functions + for name, func in self._nn_init_methods: + setattr(torch.nn.init, name, self._cache_init_func(func)) + + def _unpatch_nn_init_funcs(self): + # unpatch nn.init functions + for name, func in self._nn_init_methods: + setattr(torch.nn.init, name, func) + + def _patch_submodule_init(self): + # patch classes __init__ methods + def _activate_wrap_init(cls): + cls.__orig_init__ = cls.__init__ + cls.__init__ = self._wrap_module_init(cls.__init__) + + substitute_init_recursively(self._torch_mod_cls, _activate_wrap_init, set()) + + def _unpatch_submodule_init(self): + + def _recover_orig_init(cls): + cls.__init__ = cls.__orig_init__ + + substitute_init_recursively(self._torch_mod_cls, _recover_orig_init, set()) + + def _patch_torch_tensor_funcs(self): + # patch tensor value-setting functions + for func_name in self._torch_tensor_funcs: + origin_func_name = self._get_tmp_origin_func_ref(func_name) + origin_func = getattr(torch.Tensor, func_name) + setattr(torch.Tensor, origin_func_name, origin_func) + setattr(torch.Tensor, func_name, self._cache_init_func(origin_func)) + + def _unpatch_torch_tensor_funcs(self): + for func_name in self._torch_tensor_funcs: + origin_func_name = self._get_tmp_origin_func_ref(func_name) + origin_func = getattr(torch.Tensor, origin_func_name) + setattr(torch.Tensor, func_name, origin_func) + + def __enter__(self): + self._patch_torch_tensor_funcs() + self._patch_nn_init_funcs() + + if self._to_meta: + self._patch_submodule_init() + return self + + def __exit__(self, *args, **kwargs): + if self._to_meta: + self._unpatch_submodule_init() + self._unpatch_nn_init_funcs() + self._unpatch_torch_tensor_funcs() + + def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'): + """ + Initialize the weights of the meta-tensor model. + + Args: + model (`torch.nn.Module`): the model instantiated under the context. + device (str): the device on which weights are initialized + + """ + + def _init_recursively(module: nn.Module): + # recursively initialize the module + for mod in module.children(): + _init_recursively(mod) + + # initialize and shard tensors directly attached to the current module + for name, param in module.named_parameters(recurse=False): + _init_and_shard(module, name, param) + + for name, buf in module.named_buffers(recurse=False): + _init_and_shard(module, name, buf) + + @torch.no_grad() + def _init_and_shard(module, name, tensor): + # check whether the tensor is a buffer or parameter + is_param = isinstance(tensor, nn.parameter.Parameter) + + # get sharding spec + dist_spec = getattr(tensor, 'dist_spec', None) + pg = getattr(tensor, 'pg', None) + comp_spec = getattr(tensor, 'comp_spec', None) + + # convert the tensor from meta to materialized one + if tensor.is_meta: + materialized_tensor = torch.empty_like(tensor, device=device) + # if this tensor is a meta tensor, it must have an init function + assert tensor in self._intercepted_nn_init_func_cache + else: + materialized_tensor = tensor + + # apply init function + if tensor in self._intercepted_nn_init_func_cache: + init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1] + init_func(materialized_tensor, *args, **kwargs) + + # convert it to ColoTensor or ColoParameter + if is_param: + tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad) + else: + tensor = ColoTensor.from_torch_tensor(materialized_tensor) + + # override the original tensor + with torch.no_grad(): + setattr(module, name, tensor) + + # apply sharding + if dist_spec: + tensor.process_group = pg + tensor.set_tensor_spec(dist_spec, comp_spec) + + _init_recursively(model) + + return model diff --git a/colossalai/utils/model/utils.py b/colossalai/utils/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..75bb18df66c14ad5097de2ce69efc543d405195d --- /dev/null +++ b/colossalai/utils/model/utils.py @@ -0,0 +1,110 @@ +import torch +import functools +from typing import Optional + + +def substitute_init_recursively(cls, func, visited: set): + for subcls in cls.__subclasses__(): + substitute_init_recursively(subcls, func, visited) + if subcls not in visited: + func(subcls) + visited.add(subcls) + + +def call_to_str(base, *args, **kwargs): + """Construct a string representation of a call. + + Args: + base (str): name of the call + args (tuple, optional): args to ``base`` + kwargs (dict, optional): kwargs supplied to ``base`` + + Returns: + str: A string representation of base(*args, **kwargs) + """ + name = f'{base}(' + if args: + name += ', '.join(repr(arg) for arg in args) + if kwargs: + name += ', ' + if kwargs: + name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items()) + name += ')' + return name + + +class InsertPostInitMethodToModuleSubClasses(object): + + def __init__(self, default_dtype: Optional[torch.dtype] = None): + self._old_default_dtype = None + self._default_dtype = default_dtype + + def __enter__(self): + r""" + Enter the context scope. + """ + if self._default_dtype is not None: + self._old_default_dtype = torch.get_default_dtype() + torch.set_default_dtype(self._default_dtype) + + def preprocess_after(f): + + @functools.wraps(f) + def wrapper(module: torch.nn.Module, *args, **kwargs): + f(module, *args, **kwargs) + self._post_init_method(module, *args, **kwargs) + + return wrapper + + def _enable_class(cls): + cls._old_init = cls.__init__ + cls.__init__ = preprocess_after(cls.__init__) + + # The function is called during init subclass. + def _init_subclass(cls, **kwargs): + cls.__init__ = preprocess_after(cls.__init__) + + # Replace .__init__() for all existing subclasses of torch.nn.Module + # Excution self._post_init_method after the default init function. + substitute_init_recursively(torch.nn.modules.module.Module, _enable_class, set()) + + # holding on to the current __init__subclass__ for exit + torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__) + # Replace .__init__() for future subclasses of torch.nn.Module + torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) + + self._pre_context_exec() + return self + + def __exit__(self, exc_type, exc_value, traceback): + + if self._default_dtype is not None: + torch.set_default_dtype(self._old_default_dtype) + + def _disable_class(cls): + if not hasattr(cls, '_old_init'): + raise AttributeError( + f"_old_init is not found in the {cls.__name__}, please make sure that you have imported {cls.__name__} before entering the context." + ) + cls.__init__ = cls._old_init + + # Replace .__init__() for all existing subclasses of torch.nn.Module + substitute_init_recursively(torch.nn.modules.module.Module, _disable_class, set()) + + # Replace .__init__() for future subclasses of torch.nn.Module + torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass) + + self._post_context_exec() + # Now that we cleaned up the metaclass injection, raise the exception. + if exc_type is not None: + return False + + # To be implemented by inheriting classes + def _post_init_method(self, module, *args, **kwargs): + pass + + def _pre_context_exec(self): + pass + + def _post_context_exec(self): + pass diff --git a/colossalai/utils/moe.py b/colossalai/utils/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..7c1304fd4c373baeeb330256b0e0066b498533b2 --- /dev/null +++ b/colossalai/utils/moe.py @@ -0,0 +1,52 @@ +import torch.nn as nn +import torch.distributed as dist +from colossalai.core import global_context as gpc +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.context import ParallelMode +from .common import is_using_ddp +from typing import Dict, List + + +def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]: + """Returns a parameter dictionary, the key of which is the expert parallel + size of every parameter. Since the parameters in data parallelism is replicated + in each GPU, we set their ep_size to 1. + + Args: + model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict. + """ + epsize_param_dict = dict() + for param in model.parameters(): + if not hasattr(param, 'moe_info'): + ep_size = 1 # set ep_size to 1 for dp parameters + else: + ep_size = param.moe_info.ep_size + if ep_size not in epsize_param_dict: + epsize_param_dict[ep_size] = [] + epsize_param_dict[ep_size].append(param) + + return epsize_param_dict + + +def sync_moe_model_param(model: nn.Module): + """Make sure model parameters are consistent in MoE parallel context. + + Args: + model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. + """ + if is_using_ddp(): + + param_dict = get_moe_epsize_param_dict(model) + + # synchrosize the parameters whose dp_group is the whole world + if 1 in param_dict: + src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] + for param in param_dict[1]: + dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA)) + + for ep_size in param_dict: + # When ep_size = world_size, communication is not needed + if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: + src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group) + for param in param_dict[ep_size]: + dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group) diff --git a/colossalai/utils/multi_tensor_apply/__init__.py b/colossalai/utils/multi_tensor_apply/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94d13b339a0de106334fee99a8b73ea6e70f60dd --- /dev/null +++ b/colossalai/utils/multi_tensor_apply/__init__.py @@ -0,0 +1,3 @@ +from .multi_tensor_apply import MultiTensorApply + +multi_tensor_applier = MultiTensorApply(2048 * 32) diff --git a/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..6eda9834bdd30d4cf7b73fbc367cebd571b6459a --- /dev/null +++ b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py @@ -0,0 +1,34 @@ +# modified from https://github.com/NVIDIA/apex/blob/master/apex/multi_tensor_apply/multi_tensor_apply.py + + +class MultiTensorApply(object): + """ + Apply an operation to a list of tensors efficiently. + + Args: + chunk_size (int): Size of a chunk. + """ + + available = False + warned = False + + def __init__(self, chunk_size): + try: + import colossalai._C.fused_optim + MultiTensorApply.available = True + self.chunk_size = chunk_size + except ImportError as err: + MultiTensorApply.available = False + MultiTensorApply.import_err = err + + def check_avail(self): + if not MultiTensorApply.available: + raise RuntimeError( + "Attempted to call MultiTensorApply method, but MultiTensorApply " + "is not available, possibly because Apex was installed without " + "--cpp_ext --cuda_ext. Original import error message:", MultiTensorApply.import_err) + + def __call__(self, op, noop_flag_buffer, tensor_lists, *args): + self.check_avail() + + return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args) diff --git a/colossalai/utils/profiler/__init__.py b/colossalai/utils/profiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..90eab67c4d27d7b6e76316be097290c9fc7fa5ce --- /dev/null +++ b/colossalai/utils/profiler/__init__.py @@ -0,0 +1,2 @@ +from .legacy import * +from .profiler import profile diff --git a/colossalai/utils/profiler/extention.py b/colossalai/utils/profiler/extention.py new file mode 100644 index 0000000000000000000000000000000000000000..6726a683cc05ebb1ac5370d8c17750cd869d9ec2 --- /dev/null +++ b/colossalai/utils/profiler/extention.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod + + +class ProfilerExtension(ABC): + + @abstractmethod + def prepare_trace(self): + pass + + @abstractmethod + def start_trace(self): + pass + + @abstractmethod + def stop_trace(self): + pass + + @abstractmethod + def extend_chrome_trace(self, trace: dict) -> dict: + pass diff --git a/colossalai/utils/profiler/legacy/__init__.py b/colossalai/utils/profiler/legacy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a6f878407039e368c9dc75660e0f1239a32d049 --- /dev/null +++ b/colossalai/utils/profiler/legacy/__init__.py @@ -0,0 +1,6 @@ +from .comm_profiler import CommProfiler +from .pcie_profiler import PcieProfiler +from .prof_utils import ProfilerContext, BaseProfiler +from .mem_profiler import MemProfiler + +__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext'] diff --git a/colossalai/utils/profiler/legacy/comm_profiler.py b/colossalai/utils/profiler/legacy/comm_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..279fe311a18757844c7a1a8075253b7b4e00033e --- /dev/null +++ b/colossalai/utils/profiler/legacy/comm_profiler.py @@ -0,0 +1,308 @@ +import inspect +from pathlib import Path +from functools import partial +import torch +from torch.autograd.profiler import profile +import torch.distributed as dist +from torch.distributed import ReduceOp +from colossalai.utils import get_current_device +from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth +from typing import List, Optional + + +def _get_code_location(depth: int): + ret = [] + length = min(len(inspect.stack()), depth + 1) + for i in range(3, length): + upper_frame = inspect.stack()[i] + function_name = inspect.stack()[i - 1].function + ret.append(upper_frame.filename) + ret.append('(') + ret.append(str(upper_frame.lineno)) + ret.append('): ') + ret.append(function_name) + if i != length - 1: + ret.append('\n') + + return ''.join(ret) + + +torch_all_reduce = dist.all_reduce +torch_all_gather = dist.all_gather +torch_reduce_scatter = dist.reduce_scatter +torch_broadcast = dist.broadcast +torch_reduce = dist.reduce + + +class CommEvent(object): + """Communication Event. Used for communication time and communication + volume recording. + """ + + def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0): + self.self_count = count + self.self_comm_vol = comm_vol + self.self_cuda_time = cuda_time + + def add(self, rhs): + self.self_count += rhs.self_count + self.self_comm_vol += rhs.self_comm_vol + self.self_cuda_time += rhs.self_cuda_time + + +class CommProfiler(BaseProfiler): + """Communication profiler. Records all communication events. + """ + + def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0): + super().__init__(profiler_name="Collective_Communication", priority=0) + self.depth = 3 + depth + self.total_count = total_count + self.total_comm_vol = total_comm_vol + self.total_cuda_time = total_cuda_time + + self.ops_record = dict() + self.profiler = None + self.pending_op = None + self.pending_metadata = None + self.warn_flag = False + + def reset(self): + self.total_count = 0 + self.total_comm_vol = 0 + self.total_cuda_time = 0 + + self.ops_record = dict() + self.profiler = None + self.pending_op = None + self.pending_metadata = None + self.warn_flag = False + + def enable(self): + dist.all_reduce = partial(all_reduce, profiler=self) + dist.all_gather = partial(all_gather, profiler=self) + dist.reduce_scatter = partial(reduce_scatter, profiler=self) + dist.broadcast = partial(broadcast, profiler=self) + dist.reduce = partial(reduce, profiler=self) + + def disable(self): + dist.all_reduce = torch_all_reduce + dist.all_gather = torch_all_gather + dist.reduce_scatter = torch_reduce_scatter + dist.broadcast = torch_broadcast + dist.reduce = torch_reduce + + def to_tensorboard(self, writer): + writer.add_text(tag="Collective Communication", text_string=self.result_str("\n\n")) + + def to_file(self, filename: Path): + with open(filename, "w") as f: + f.write(self.result_str()) + + def show(self): + print(self.result_str()) + + def result_str(self, sep: str = "\n"): + res = [] + + def append(s: str = None): + if s is not None: + res.append(s) + res.append(sep) + + if self.warn_flag: + append("Warnning: there exists multiple communication operations in the same time. As a result, " + "the profiling result is not accurate.") + + if self.total_cuda_time == 0: + return "No collective communication has been called yet!" + + append("Collective communication profiling result:") + append("total cuda time: {}".format(_format_time(self.total_cuda_time))) + append("average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time))) + append("total number of calls: {}".format(self.total_count)) + append("All events:") + + seperation = '-' * 74 + row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2 + + append(seperation) + append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls')) + append(seperation) + + show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time) + for location, event in show_list: + append(location) + append( + row_format.format('', _format_time(event.self_cuda_time), + '{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0), + _format_memory(event.self_comm_vol), + _format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count)) + append() + + return ''.join(res) + + @property + def has_aync_op(self): + return self.pending_op is not None + + def activate_profiler(self, kn: str, vol: float): + self.pending_metadata = (kn, _get_code_location(self.depth), vol) + self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True) + self.profiler.__enter__() + + def close_profiler(self, group=None): + assert self.profiler is not None, "There is no running dist op" + kernel_name, code_location, vol = self.pending_metadata + self.profiler.__exit__(None, None, None) + + if self.profiler.enabled and dist.get_world_size(group) > 1: + assert_flag = 0 + current_comm_event = None + events = self.profiler.function_events + for event in events: + if kernel_name in event.name: + assert assert_flag == 0, "Multiple dist ops has been called " + current_comm_event = CommEvent(1, vol, event.self_cuda_time_total) + assert_flag += 1 + + assert current_comm_event is not None, "dist op has not been found" + + buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device()) + torch_all_reduce(buffer, op=ReduceOp.MIN, group=group) + current_comm_event.self_cuda_time = buffer.item() + + self.total_count += current_comm_event.self_count + self.total_comm_vol += current_comm_event.self_comm_vol + self.total_cuda_time += current_comm_event.self_cuda_time + if code_location in self.ops_record: + self.ops_record[code_location].add(current_comm_event) + else: + self.ops_record[code_location] = current_comm_event + + self.profiler = None + self.pending_op = None + self.pending_metadata = None + + def wait_async_op(self): + if self.pending_op is not None: + op = self.pending_op + op.wait() + self.close_profiler() + + +class CommHandler(object): + """Communication handler. A dummy handler to wait aync operations. + """ + + def __init__(self, profiler: CommProfiler): + super().__init__() + self.prof = profiler + + def wait(self): + self.prof.wait_async_op() + + +def async_check(profiler: CommProfiler): + if profiler.pending_op is not None: + profiler.warn_flag = True + profiler.wait_async_op() + + +def all_reduce(tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_size = dist.get_world_size(group) + correction = 2 * (comm_size - 1) / comm_size + comm_vol = correction * tensor.element_size() * tensor.numel() + profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol) + profiler.pending_op = torch_all_reduce(tensor, op, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def reduce_scatter(output: torch.Tensor, + input_list: List[torch.Tensor], + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_size = dist.get_world_size(group) + correction = (comm_size - 1) / comm_size + comm_vol = 0 + for tensor in input_list: + comm_vol += tensor.element_size() * tensor.numel() + comm_vol *= correction + profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol) + profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def all_gather(tensor_list: List[torch.Tensor], + tensor: torch.Tensor, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_size = dist.get_world_size(group) + correction = (comm_size - 1) / comm_size + comm_vol = 0 + for ten in tensor_list: + comm_vol += ten.element_size() * ten.numel() + comm_vol *= correction + profiler.activate_profiler("ncclKernel_AllGather_", comm_vol) + profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def broadcast(tensor: torch.Tensor, + src: int, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_vol = 1.0 * tensor.element_size() * tensor.numel() + profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol) + profiler.pending_op = torch_broadcast(tensor, src, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def reduce(tensor: torch.Tensor, + dst: int, + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_vol = 1.0 * tensor.element_size() * tensor.numel() + profiler.activate_profiler("ncclKernel_Reduce_", comm_vol) + profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) diff --git a/colossalai/utils/profiler/legacy/pcie_profiler.py b/colossalai/utils/profiler/legacy/pcie_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..7df0f7043006b90a3f733a858b31fb1180256b1d --- /dev/null +++ b/colossalai/utils/profiler/legacy/pcie_profiler.py @@ -0,0 +1,148 @@ +from pathlib import Path +from torch.autograd.profiler import profile +from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth +from typing import List + + +def _get_size(dtype: str): + if dtype == "fp16": + return 2 + elif dtype == "fp32": + return 4 + else: + raise NotImplementedError + + +def _get_numel(my_list: List[int]) -> int: + from functools import reduce + from operator import mul + return reduce(mul, my_list) + + +def _reduce_location(locations: List[str]) -> str: + ret = [] + for lo in locations: + ret.append(lo) + ret.append("\n") + ret = ret[:-1] + return ''.join(ret) + + +class PcieEvent(object): + """Pcie Event. + """ + + def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0): + self.count = count + self.pcie_vol = pcie_vol + self.cuda_time = cuda_time + + def add(self, rhs): + self.count += rhs.count + self.pcie_vol += rhs.pcie_vol + self.cuda_time += rhs.cuda_time + + +class PcieProfiler(BaseProfiler): + """Pcie profiler. Records all data transmission between CPU and GPU. + + TODO: Merge pcie profiler into communication profiler + """ + + def __init__(self, dtype: str = "fp32", depth: int = 1): + super().__init__(profiler_name="Pcie", priority=10) + self.depth = depth + self.data_size = _get_size(dtype) + self.h2d_count = 0 + self.h2d_time = 0 + self.d2h_count = 0 + self.d2h_time = 0 + + self.ops_record = dict() + self.profiler = None + + def reset(self): + self.h2d_count = 0 + self.h2d_time = 0 + self.d2h_count = 0 + self.d2h_time = 0 + + self.ops_record = dict() + self.profiler = None + + def enable(self): + self.profiler = profile(enabled=True, + use_cuda=True, + use_cpu=True, + use_kineto=True, + record_shapes=True, + with_stack=True) + self.profiler.__enter__() + + def disable(self): + self.profiler.__exit__(None, None, None) + + if self.profiler.enabled: + events = self.profiler.function_events + for event in events: + if event.name == "aten::copy_": + t_shape = event.input_shapes[0] + if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0: + continue + current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total) + code_location = _reduce_location(event.stack[:self.depth]) + if code_location in self.ops_record: + self.ops_record[code_location].add(current_comm_event) + else: + self.ops_record[code_location] = current_comm_event + elif 'Memcpy HtoD' in event.name: + self.h2d_count += 1 + self.h2d_time += event.cuda_time_total + elif 'Memcpy DtoH' in event.name: + self.d2h_count += 1 + self.d2h_time += event.cuda_time_total + + self.profiler = None + + def to_tensorboard(self, writer): + writer.add_text(tag="Data Transmission", text_string=self.result_str("\n\n")) + + def to_file(self, filename: Path): + with open(filename, "w") as f: + f.write(self.result_str()) + + def show(self): + print(self.result_str()) + + def result_str(self, sep: str = "\n"): + res = [] + + def append(s: str = None): + if s is not None: + res.append(s) + res.append(sep) + + append("Pcie profiling result:") + append("time of data transmission (CPU -> GPU): {}".format(_format_time(self.h2d_time))) + append("number of transmission (CPU -> GPU): {}".format(self.h2d_count)) + append("time of data transmission (GPU -> CPU): {}".format(_format_time(self.d2h_time))) + append("number of transmission (GPU -> CPU): {}".format(self.d2h_count)) + + append("Possible data transmission events in PCIE:") + + seperation = '-' * 62 + row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2 + + append(seperation) + append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls')) + append(seperation) + + show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time) + for location, event in show_list: + append(location) + append( + row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol), + _format_bandwidth(event.pcie_vol, event.cuda_time), event.count)) + append() + + return ''.join(res) diff --git a/colossalai/utils/profiler/legacy/prof_utils.py b/colossalai/utils/profiler/legacy/prof_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..df05c7116c6a6fc0a8e3e2644e373d6d02dfdcb2 --- /dev/null +++ b/colossalai/utils/profiler/legacy/prof_utils.py @@ -0,0 +1,131 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Union, List +from colossalai.core import global_context as gpc + + +# copied from high version pytorch to support low version +def _format_time(time_us): + """Defines how to format time in FunctionEvent""" + US_IN_SECOND = 1000.0 * 1000.0 + US_IN_MS = 1000.0 + if time_us >= US_IN_SECOND: + return '{:.3f}s'.format(time_us / US_IN_SECOND) + if time_us >= US_IN_MS: + return '{:.3f}ms'.format(time_us / US_IN_MS) + return '{:.3f}us'.format(time_us) + + +# copied from high version pytorch to support low version +def _format_memory(nbytes): + """Returns a formatted memory size string""" + KB = 1024 + MB = 1024 * KB + GB = 1024 * MB + if (abs(nbytes) >= GB): + return '{:.2f} GB'.format(nbytes * 1.0 / GB) + elif (abs(nbytes) >= MB): + return '{:.2f} MB'.format(nbytes * 1.0 / MB) + elif (abs(nbytes) >= KB): + return '{:.2f} KB'.format(nbytes * 1.0 / KB) + else: + return str(nbytes) + ' B' + + +def _format_bandwidth(volme: float or int, time_us: int): + sec_div_mb = (1000.0 / 1024.0)**2 + mb_per_sec = volme / time_us * sec_div_mb + + if mb_per_sec >= 1024.0: + return '{:.3f} GB/s'.format(mb_per_sec / 1024.0) + else: + return '{:.3f} MB/s'.format(mb_per_sec) + + +class BaseProfiler(ABC): + + def __init__(self, profiler_name: str, priority: int): + self.name = profiler_name + self.priority = priority + + @abstractmethod + def enable(self): + pass + + @abstractmethod + def disable(self): + pass + + @abstractmethod + def to_tensorboard(self, writer): + pass + + @abstractmethod + def to_file(self, filename: Path): + pass + + @abstractmethod + def show(self): + pass + + +class ProfilerContext(object): + """Profiler context manager + + Usage:: + + world_size = 4 + inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device()) + outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device()) + outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0)) + + cc_prof = CommProfiler() + + with ProfilerContext([cc_prof]) as prof: + op = dist.all_reduce(inputs, async_op=True) + dist.all_gather(outputs_list, inputs) + op.wait() + dist.reduce_scatter(inputs, outputs_list) + dist.broadcast(inputs, 0) + dist.reduce(inputs, 0) + + prof.show() + """ + + def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True): + self.enable = enable + self.profilers = sorted(profilers, key=lambda prof: prof.priority) + + def __enter__(self): + if self.enable: + for prof in self.profilers: + prof.enable() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.enable: + for prof in self.profilers: + prof.disable() + + def to_tensorboard(self, writer): + from torch.utils.tensorboard import SummaryWriter + + assert isinstance(writer, SummaryWriter), \ + f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.' + + for prof in self.profilers: + prof.to_tensorboard(writer) + + def to_file(self, log_dir: Union[str, Path]): + if isinstance(log_dir, str): + log_dir = Path(log_dir) + + if not log_dir.exists(): + log_dir.mkdir(parents=True, exist_ok=True) + for prof in self.profilers: + log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log') + prof.to_file(log_file) + + def show(self): + for prof in self.profilers: + prof.show() diff --git a/colossalai/utils/profiler/profiler.py b/colossalai/utils/profiler/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..8f43a0b96de0a8deb4818644bb48df7ddaad896c --- /dev/null +++ b/colossalai/utils/profiler/profiler.py @@ -0,0 +1,201 @@ +import os +from typing import List +from colossalai.engine import Engine +from torch.profiler import profile as torch_profile +from torch.profiler.profiler import ProfilerAction +from typing import Any, Callable, Iterable, Optional +from torch.autograd import ProfilerActivity +import json +import os +import tempfile +import gzip +from colossalai.utils.profiler.extention import ProfilerExtension +from colossalai.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention +from colossalai.logging import get_dist_logger + + +class profile(torch_profile): + """Profiler context manager. + + Args: + activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values: + ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``. + Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA. + schedule (callable): callable that takes step (int) as a single parameter and returns + ``ProfilerAction`` value that specifies the profiler action to perform at each step. + on_trace_ready (callable): callable that is called at each step when ``schedule`` + returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling. + engine (Optional[Engine], optional): An ``Engine`` instance. Defaults to None. + record_shapes (bool): save information about operator's input shapes. + profile_memory (bool): track tensor memory allocation/deallocation. + with_stack (bool): record source information (file and line number) for the ops. + with_flops (bool): use formula to estimate the FLOPs (floating point operations) of specific operators + (matrix multiplication and 2D convolution). + with_modules (bool): record module hierarchy (including function names) + corresponding to the callstack of the op. e.g. If module A's forward call's + module B's forward which contains an aten::add op, + then aten::add's module hierarchy is A.B + Note that this support exist, at the moment, only for TorchScript models + and not eager mode models. + profile_stateful_tensor_memory (bool): track stateful tensor memory usage. ``engine`` must not be None if you enable this. + + .. note:: + Use :func:`~torch.profiler.schedule` to generate the callable schedule. + Non-default schedules are useful when profiling long training jobs + and allow the user to obtain multiple traces at the different iterations + of the training process. + The default schedule simply records all the events continuously for the + duration of the context manager. + + .. note:: + Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard: + + ``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)`` + + After profiling, result files can be found in the specified directory. Use the command: + + ``tensorboard --logdir dir_name`` + + to see the results in TensorBoard. + For more information, see + `PyTorch Profiler TensorBoard Plugin `__ + + .. note:: + Enabling shape and stack tracing results in additional overhead. + When record_shapes=True is specified, profiler will temporarily hold references to the tensors; + that may further prevent certain optimizations that depend on the reference count and introduce + extra tensor copies. + + Examples: + + .. code-block:: python + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + ) as p: + code_to_profile() + print(p.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + + Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions: + + .. code-block:: python + + # Non-default profiler schedule allows user to turn profiler on and off + # on different iterations of the training loop; + # trace_handler is called every time a new trace becomes available + def trace_handler(prof): + print(prof.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json") + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + + # In this example with wait=1, warmup=1, active=2, + # profiler will skip the first step/iteration, + # start warming up on the second, record + # the third and the forth iterations, + # after which the trace will become available + # and on_trace_ready (when set) is called; + # the cycle repeats starting with the next step + + schedule=torch.profiler.schedule( + wait=1, + warmup=1, + active=2), + on_trace_ready=trace_handler + # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log') + # used when outputting for tensorboard + ) as p: + for iter in range(N): + code_iteration_to_profile(iter) + # send a signal to the profiler that the next iteration has started + p.step() + """ + + def __init__(self, + *, + activities: Optional[Iterable[ProfilerActivity]] = None, + schedule: Optional[Callable[[int], ProfilerAction]] = None, + on_trace_ready: Optional[Callable[..., Any]] = None, + engine: Optional[Engine] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + with_modules: bool = False, + profile_stateful_tensor_memory: bool = False) -> None: + super().__init__(activities=activities, + schedule=schedule, + on_trace_ready=on_trace_ready, + record_shapes=record_shapes, + profile_memory=profile_memory, + with_stack=with_stack, + with_flops=with_flops, + with_modules=with_modules) + self._logger = get_dist_logger() + self.extentions: List[ProfilerExtension] = [] + if profile_stateful_tensor_memory: + if engine is None: + self._logger.warning('Ignore "profile_model_data" since engine is None', ranks=[0]) + else: + self.extentions.append(StatefulTensorMemoryProfilerExtention(engine)) + + def prepare_trace(self) -> None: + if hasattr(super(), 'prepare_trace'): + super().prepare_trace() + elif hasattr(super(), '_start_warmup'): + super()._start_warmup() + for ext in self.extentions: + ext.prepare_trace() + + def _start_warmup(self): + self.prepare_trace() + + def start_trace(self): + if hasattr(super(), '_start_trace'): + super()._start_trace() + elif hasattr(super(), 'start_trace'): + super().start_trace() + for ext in self.extentions: + ext.start_trace() + + def _start_trace(self): + self.start_trace() + + def stop_trace(self): + if hasattr(super(), '_stop_trace'): + super()._stop_trace() + elif hasattr(super(), 'stop_trace'): + super().stop_trace() + for ext in self.extentions: + ext.stop_trace() + + def _stop_trace(self): + self.stop_trace() + + def export_chrome_trace(self, path: str): + """ + Exports the collected trace in Chrome JSON format. + """ + assert self.profiler + fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False) + fp.close() + retvalue = self.profiler.export_chrome_trace(fp.name) + with open(fp.name) as fin: + trace = json.load(fin) + for ext in self.extentions: + trace = ext.extend_chrome_trace(trace) + open_func = gzip.open if path.endswith('.gz') else open + with open_func(path, 'wt') as fout: + json.dump(trace, fout) + + os.remove(fp.name) + return retvalue diff --git a/colossalai/utils/profiler/stateful_tensor_mem_extention.py b/colossalai/utils/profiler/stateful_tensor_mem_extention.py new file mode 100644 index 0000000000000000000000000000000000000000..127055c8c1efa5e193ee04cbf6467df35a5972a7 --- /dev/null +++ b/colossalai/utils/profiler/stateful_tensor_mem_extention.py @@ -0,0 +1,133 @@ +import os +import threading +import time +import torch +from enum import Enum +from typing import List +from colossalai.gemini.stateful_tensor import StatefulTensor +from colossalai.gemini.ophooks import BaseOpHook +from colossalai.engine import Engine +from colossalai.utils.profiler.extention import ProfilerExtension + + +class DeviceType(Enum): + CPU = 0 + CUDA = 1 + + +def get_timestamp_us(): + return int(time.time() * 1e6) + + +def generic_instant_event(name, pid, tid, timestamp, args): + return {'ph': 'i', 's': 't', 'name': name, 'pid': pid, 'tid': tid, 'ts': timestamp, 'args': args} + + +class StatefulTensorMemoryEvent: + EVENT_NAME = '[statefulTensorMemory]' + + def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None: + self.pid = os.getpid() + self.tid = threading.get_ident() + self.timestamp = timestamp + self.device_type = device_type + self.device_id = torch.cuda.current_device() if device_type == DeviceType.CUDA else -1 + self.bytes = bytes_ + + def state_dict(self): + return generic_instant_event(StatefulTensorMemoryEvent.EVENT_NAME, self.pid, self.tid, self.timestamp, { + 'Device Type': self.device_type.value, + 'Device Id': self.device_id, + 'Bytes': self.bytes + }) + + +class StatefulTensorMemoryTracer: + + def __init__(self) -> None: + self.events: List[StatefulTensorMemoryEvent] = [] + self._tracing = False + + def sample(self): + cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] + cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu'] + timestamp = get_timestamp_us() + if self._tracing: + self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CUDA, cuda_mem)) + self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CPU, cpu_mem)) + + def start_trace(self): + self.events.clear() + self._tracing = True + + def stop_trace(self): + self._tracing = False + + def state_dict(self): + return [event.state_dict() for event in self.events] + + +class StatefulTensorMemoryTracerHook(BaseOpHook): + + def __init__(self, tracer: StatefulTensorMemoryTracer): + super().__init__() + self.tracer = tracer + self._enable = False + + def pre_fwd_exec(self, module: torch.nn.Module, *args): + if self._enable: + self.tracer.sample() + + def post_fwd_exec(self, module: torch.nn.Module, *args): + if self._enable: + self.tracer.sample() + + def pre_bwd_exec(self, module: torch.nn.Module, input_, output): + if self._enable: + self.tracer.sample() + + def post_bwd_exec(self, module: torch.nn.Module, input_): + if self._enable: + self.tracer.sample() + + def post_iter(self): + if self._enable: + self.tracer.sample() + + def enable(self): + self._enable = True + + def disable(self): + self._enable = False + + +class StatefulTensorMemoryProfilerExtention(ProfilerExtension): + + def __init__(self, engine: Engine) -> None: + self.engine = engine + self.tracer = StatefulTensorMemoryTracer() + self.hook = StatefulTensorMemoryTracerHook(self.tracer) + self.hook_registered = False + + def prepare_trace(self): + self.hook.enable() + if not self.hook_registered: + self.engine.add_hook(self.hook) + self.hook_registered = True + + def start_trace(self): + self.prepare_trace() + self.tracer.start_trace() + + def stop_trace(self): + self.tracer.stop_trace() + self.hook.disable() + if self.hook_registered: + self.engine.remove_hook(self.hook) + # remove_hook is not implemented now + # FIXME(ver217): uncomment below line when remove_hook is implemented + # self.hook_registered = False + + def extend_chrome_trace(self, trace: dict) -> dict: + trace['traceEvents'].extend(self.tracer.state_dict()) + return trace diff --git a/colossalai/utils/rank_recorder/README.md b/colossalai/utils/rank_recorder/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e30a925d2a9291de9d8eeb01119237d24a1bc38c --- /dev/null +++ b/colossalai/utils/rank_recorder/README.md @@ -0,0 +1,72 @@ +# Rank Recorder +This is a useful tool to get the records of certain functions in each rank. The records of each rank will dump into a json file after the end of multiple process program. You can parse and visualise the json file easily. + +Before using the tool, you should ensure dist.is_initialized() return true before exit of program. + +## Usage + +Is very simple: + +```python +from colossalai.utils.rank_recorder import recorder + +... +... + +with recorder(record_name, current_rank) as r: + """procedure to record + """ + +``` + +## Example +This is a demo to display kernel select in cuda and visualise the cost of several procedures in each rank. + +```python +import time +import os +import logging +logging.disable(logging.INFO) + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from colossalai.utils.rank_recorder import recorder + + +WORLD_SIZE = 4 + +# config the export image here +# If you want to dive into the detail, format 'svg' is recommended +recorder.export_format = 'png' +recorder.export_name = 'kernel_select' +recorder.dpi = 500 + +def calc(x, y): + a = torch.randn(x, y).cuda() + b = torch.randn(x, y).cuda() + c = sum(a * b) + return c + +def worker(rank): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '29020' + dist.init_process_group(backend='nccl', world_size=WORLD_SIZE, rank=rank) + print(dist.get_rank(), "enter") + time.sleep(0.1 * rank) + + with recorder("calc_1(x100)", rank) as r: + calc(100, 100) + + with recorder("calc_2(x400)", rank) as r: + calc(400, 400) + + with recorder("calc_2(x200)", rank) as r: + calc(200, 200) + +if __name__ == "__main__": + mp.spawn(worker, nprocs=WORLD_SIZE) +``` + +run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder. \ No newline at end of file diff --git a/colossalai/utils/rank_recorder/__init__.py b/colossalai/utils/rank_recorder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1274d0e7dbc5277a60fb4d6fcd16c972b85c305c --- /dev/null +++ b/colossalai/utils/rank_recorder/__init__.py @@ -0,0 +1,3 @@ +from colossalai.utils.rank_recorder.rank_recorder import recorder + +__all__ = ["recorder"] \ No newline at end of file diff --git a/colossalai/utils/rank_recorder/rank_recorder.py b/colossalai/utils/rank_recorder/rank_recorder.py new file mode 100644 index 0000000000000000000000000000000000000000..c088ceeb2e87727ca9f4e5ac4a71d0721405f4a1 --- /dev/null +++ b/colossalai/utils/rank_recorder/rank_recorder.py @@ -0,0 +1,178 @@ +import time +from typing import List, Dict +import json +import os +import time +import shutil +import atexit + +import torch +import torch.distributed as dist + +import json +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors + +cmap = list(mcolors.TABLEAU_COLORS.values()) + +LOG_FOLDER = "record.log" +MAX_WAIT_TIME = 20 + + +class Event: + + def __init__(self, start: int, end: int, name: str, rank: int) -> None: + self.start = start + self.end = end + self.name = name + self.rank = rank + + +class Recorder: + + def __init__(self) -> None: + self.rank_to_history: Dict[int, List[Event]] = {} + self.base_time = time.time() + self.temp_event = None + + self.export_format = 'png' + self.export_name = 'test' + self.dpi = 500 + self.theme = 'dark_background' + self.figure_width = 30 + self.figure_height = 10 + self.legend_fontsize = 16 + self.device_fontsize = 20 + self.bar_height = 0.2 + + if not os.path.exists(LOG_FOLDER): + os.makedirs(LOG_FOLDER) + + def start(self, name: str, rank: int): + # TODO : add lock to prevent conflict + torch.cuda.synchronize() + start_time = time.time() + self.temp_event = Event(start_time, None, name, rank) + + def end(self): + assert self.temp_event is not None, "`start` before `end`" + torch.cuda.synchronize() + end_time = time.time() + self.temp_event.end = end_time + rank = self.temp_event.rank + if rank not in self.rank_to_history: + self.rank_to_history[rank] = [] + self.rank_to_history[rank].append(self.temp_event) + self.temp_event = None + + def get_history(self): + return self.history + + def __call__(self, name: str, rank: str): + self.temp_name = name + self.temp_rank = rank + return self + + def __enter__(self): + name = self.temp_name + rank = self.temp_rank + self.start(name, rank) + + def __exit__(self, *args): + self.end() + + def dump_record(self): + rank = dist.get_rank() + rank_to_history = self.rank_to_history + records = {'base_time': self.base_time, 'content': {}} + for record_rank in rank_to_history: + history = rank_to_history[record_rank] + recs = [] + for event in history: + rec = {'start': event.start, 'end': event.end, 'name': event.name} + recs.append(rec) + records['content'][record_rank] = recs + + dump_name = f'{rank}.json' + dump_path = os.path.join(LOG_FOLDER, dump_name) + with open(dump_path, 'w', encoding='utf-8') as f: + json.dump(records, f, ensure_ascii=False) + + def merge_recode(self): + base_time = self.base_time + world_size = dist.get_world_size() + + wait_time = 0 + while True: + time.sleep(0.1) + log_num = len(os.listdir(LOG_FOLDER)) + if log_num == world_size: + break + + wait_time += 1 + if wait_time >= MAX_WAIT_TIME: + break + + # merge + logs_path = [os.path.join(LOG_FOLDER, file) for file in os.listdir(LOG_FOLDER)] + recoders = {} + for path in logs_path: + with open(path, 'r', encoding='utf-8') as f: + recs = json.load(f) + for record_rank in recs['content']: + history = recs['content'][record_rank] + recoders[record_rank] = [] + for rec in history: + recoders[record_rank].append({ + 'start': rec['start'] - base_time, + 'end': rec['end'] - base_time, + 'name': rec['name'] + }) + + shutil.rmtree(LOG_FOLDER) + with open(self.export_name + '.json', 'w', encoding='utf-8') as f: + json.dump(recoders, f, ensure_ascii=False) + + def visualise_record(self): + with open(self.export_name + '.json', 'r', encoding='utf-8') as f: + records = json.load(f) + records = dict(records) + ranks = list(sorted(records.keys())) + + name_list = {} + plots = {} + plt.figure(dpi=self.dpi, figsize=[self.figure_width, self.figure_height]) + plt.style.use(self.theme) + + for rank in ranks: + rank_records = records[rank] + for rec in rank_records: + s = rec['start'] + e = rec['end'] + name = rec['name'] + if name not in name_list: + name_list[name] = len(name_list) + bar = plt.barh(rank, width=e - s, height=self.bar_height, left=s, color=cmap[name_list[name]]) + if name not in plots: + plots[name] = bar + + plt.legend(list(plots.values()), list(plots.keys()), loc="upper left", fontsize=self.legend_fontsize) + plt.yticks(ticks=ranks, labels=[f'Device:{rank}' for rank in ranks], fontsize=self.device_fontsize) + plt.grid(axis='x') + plt.savefig("{}.{}".format(self.export_name, self.export_format)) + + def exit_worker(self): + if len(self.rank_to_history) == 0: + return + self.dump_record() + # if this is rank 0, wait for merge + rank = dist.get_rank() + + if rank == 1: + # take the base time of rank 0 as standard + self.merge_recode() + self.visualise_record() + + +recorder = Recorder() +atexit.register(recorder.exit_worker) diff --git a/colossalai/utils/tensor_detector/__init__.py b/colossalai/utils/tensor_detector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c68aa4009bb7c6a2ccebaa45824e0e069bcbb8 --- /dev/null +++ b/colossalai/utils/tensor_detector/__init__.py @@ -0,0 +1 @@ +from .tensor_detector import TensorDetector diff --git a/colossalai/utils/tensor_detector/readme.md b/colossalai/utils/tensor_detector/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..840dc8f4eca648f2d8b4ffc286b430003169216e --- /dev/null +++ b/colossalai/utils/tensor_detector/readme.md @@ -0,0 +1,128 @@ +# Tensor Detector + +This tool supports you to detect tensors on both CPU and GPU. However, there will always be some strange tensors on CPU, including the rng state of PyTorch. + +## Example + +An example is worth than a thousand words. + +The code below defines a simple MLP module, with which we will show you how to use the tool. + +```python +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.mlp = nn.Sequential(nn.Linear(64, 8), + nn.ReLU(), + nn.Linear(8, 32)) + def forward(self, x): + return self.mlp(x) +``` + +And here is how to use the tool. + +```python +from colossalai.utils import TensorDetector + +# create random data +data = torch.rand(64, requires_grad=True).cuda() +data.retain_grad() +# create the module +model = MLP().cuda() +# create the detector +# by passing the model to the detector, it can distinguish module parameters from common tensors +detector = TensorDetector(include_cpu=False, module=model) +detector.detect() + +out = model(data) + +detector.detect() + +loss = out.sum() +loss.backward() + +detector.detect() +``` + +I have made some comments on the right of the output for your understanding. + +Note that the total `Mem` of all the tensors and parameters is not equal to `Total GPU Memery Allocated`. PyTorch's memory management is really complicated, and for models of a large scale, it's impossible to figure out clearly. + +**The order of print is not equal to the order the tensor creates, but they are really close.** + +```bash +------------------------------------------------------------------------------------------------------------ + Tensor device shape grad dtype Mem +------------------------------------------------------------------------------------------------------------ ++ Tensor cuda:0 (64,) True torch.float32 256 B # data ++ mlp.0.weight cuda:0 (8, 64) True torch.float32 2.0 KB ++ mlp.0.bias cuda:0 (8,) True torch.float32 32 B ++ mlp.2.weight cuda:0 (32, 8) True torch.float32 1.0 KB ++ mlp.2.bias cuda:0 (32,) True torch.float32 128 B +------------------------------------------------------------------------------------------------------------ +Detect Location: "test_tensor_detector.py" line 27 +Totle GPU Memery Allocated on cuda:0 is 4.5 KB +------------------------------------------------------------------------------------------------------------ + + +------------------------------------------------------------------------------------------------------------ + Tensor device shape grad dtype Mem +------------------------------------------------------------------------------------------------------------ ++ Tensor cuda:0 (8,) True torch.float32 32 B # activation ++ Tensor cuda:0 (32,) True torch.float32 128 B # output +------------------------------------------------------------------------------------------------------------ +Detect Location: "test_tensor_detector.py" line 30 +Totle GPU Memery Allocated on cuda:0 is 5.5 KB +------------------------------------------------------------------------------------------------------------ + + +------------------------------------------------------------------------------------------------------------ + Tensor device shape grad dtype Mem +------------------------------------------------------------------------------------------------------------ ++ Tensor cuda:0 () True torch.float32 4 B # loss +------------------------------------------------------------------------------------------------------------ +Detect Location: "test_tensor_detector.py" line 32 +Totle GPU Memery Allocated on cuda:0 is 6.0 KB +------------------------------------------------------------------------------------------------------------ + + +------------------------------------------------------------------------------------------------------------ + Tensor device shape grad dtype Mem +------------------------------------------------------------------------------------------------------------ ++ Tensor (with grad) cuda:0 (64,) True torch.float32 512 B # data with grad ++ mlp.0.weight (with grad) cuda:0 (8, 64) True torch.float32 4.0 KB # for use data.retain_grad() ++ mlp.0.bias (with grad) cuda:0 (8,) True torch.float32 64 B ++ mlp.2.weight (with grad) cuda:0 (32, 8) True torch.float32 2.0 KB ++ mlp.2.bias (with grad) cuda:0 (32,) True torch.float32 256 B + +- mlp.0.weight cuda:0 (8, 64) True torch.float32 2.0 KB +- mlp.0.bias cuda:0 (8,) True torch.float32 32 B +- mlp.2.weight cuda:0 (32, 8) True torch.float32 1.0 KB +- mlp.2.bias cuda:0 (32,) True torch.float32 128 B +- Tensor cuda:0 (64,) True torch.float32 256 B +- Tensor cuda:0 (8,) True torch.float32 32 B # deleted activation +------------------------------------------------------------------------------------------------------------ +Detect Location: "test_tensor_detector.py" line 34 +Totle GPU Memery Allocated on cuda:0 is 10.0 KB +------------------------------------------------------------------------------------------------------------ + + +------------------------------------------------------------------------------------------------------------ + Tensor device shape grad dtype Mem +------------------------------------------------------------------------------------------------------------ ++ Tensor cuda:0 (64,) False torch.float32 256 B ++ Tensor cuda:0 (8, 64) False torch.float32 2.0 KB ++ Tensor cuda:0 (8,) False torch.float32 32 B ++ Tensor cuda:0 (32, 8) False torch.float32 1.0 KB ++ Tensor cuda:0 (32,) False torch.float32 128 B +------------------------------------------------------------------------------------------------------------ +Detect Location: "test_tensor_detector.py" line 36 +Totle GPU Memery Allocated on cuda:0 is 14.0 KB +------------------------------------------------------------------------------------------------------------ +``` + +## Reference + + This tool was inspired by https://github.com/Stonesjtu/pytorch_memlab/blob/master/pytorch_memlab/mem_reporter.py + and https://github.com/Oldpan/Pytorch-Memory-Utils + diff --git a/colossalai/utils/tensor_detector/tensor_detector.py b/colossalai/utils/tensor_detector/tensor_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..a8186f76834c1eec3fdb28c7b6e12dfd6e65260f --- /dev/null +++ b/colossalai/utils/tensor_detector/tensor_detector.py @@ -0,0 +1,179 @@ +import gc +import inspect +import torch +import torch.nn as nn +from typing import Optional +from collections import defaultdict + +LINE_WIDTH = 108 +LINE = '-' * LINE_WIDTH + '\n' + + +class TensorDetector(): + + def __init__(self, + show_info: bool = True, + log: str = None, + include_cpu: bool = False, + module: Optional[nn.Module] = None): + """This class is a detector to detect tensor on different devices. + + Args: + show_info (bool, optional): whether to print the info on screen, default True. + log (str, optional): the file name to save the log. Defaults to None. + include_cpu (bool, optional): whether to detect tensor on cpu, default False. + module (Optional[:class:`nn.Module`]): when sending an ``nn.Module`` object, + the detector can name the tensors detected better. + """ + self.show_info = show_info + self.log = log + self.include_cpu = include_cpu + self.tensor_info = defaultdict(list) + self.saved_tensor_info = defaultdict(list) + self.order = [] + self.detected = [] + self.devices = [] + self.info = "" + + self.module = module + if isinstance(module, nn.Module): + # if module is an instance of nn.Module, we can name the parameter with its real name + for name, param in module.named_parameters(): + self.tensor_info[id(param)].append(name) + self.tensor_info[id(param)].append(param.device) + self.tensor_info[id(param)].append(param.shape) + self.tensor_info[id(param)].append(param.requires_grad) + self.tensor_info[id(param)].append(param.dtype) + self.tensor_info[id(param)].append(self.get_tensor_mem(param)) + + def get_tensor_mem(self, tensor): + # calculate the memory occupied by a tensor + memory_size = tensor.element_size() * tensor.storage().size() + if (tensor.is_leaf or tensor.retains_grad) and tensor.grad is not None: + grad_memory_size = tensor.grad.element_size() * tensor.grad.storage().size() + memory_size += grad_memory_size + return self.mem_format(memory_size) + + def mem_format(self, real_memory_size): + # format the tensor memory into a reasonal magnitude + if real_memory_size >= 2**30: + return str(real_memory_size / (2**30)) + ' GB' + if real_memory_size >= 2**20: + return str(real_memory_size / (2**20)) + ' MB' + if real_memory_size >= 2**10: + return str(real_memory_size / (2**10)) + ' KB' + return str(real_memory_size) + ' B' + + def collect_tensors_state(self): + for obj in gc.get_objects(): + if torch.is_tensor(obj): + # skip cpu tensor when include_cpu is false and the tensor we have collected before + if (not self.include_cpu) and obj.device == torch.device('cpu'): + continue + self.detected.append(id(obj)) + # skip paramters we had added in __init__ when module is an instance of nn.Module for the first epoch + if id(obj) not in self.tensor_info: + + name = type(obj).__name__ + # after backward, we want to update the records, to show you the change + if isinstance(self.module, nn.Module) and name == 'Parameter': + if obj.grad is not None: + # with grad attached + for par_name, param in self.module.named_parameters(): + if param.requires_grad and param.grad.equal(obj.grad): + name = par_name + ' (with grad)' + else: + # with no grad attached + # there will be no new paramters created during running + # so it must be in saved_tensor_info + continue + # we can also marked common tensors as tensor(with grad) + if name == 'Tensor' and (obj.is_leaf or obj.retains_grad): + if obj.grad is not None: + name = name + ' (with grad)' + # in fact, common tensor have no grad + # unless you set retain_grad() + if id(obj) in self.saved_tensor_info.keys() and name == self.saved_tensor_info[id(obj)][0]: + continue + + self.tensor_info[id(obj)].append(name) + self.tensor_info[id(obj)].append(obj.device) + self.tensor_info[id(obj)].append(obj.shape) + self.tensor_info[id(obj)].append(obj.requires_grad) + self.tensor_info[id(obj)].append(obj.dtype) + self.tensor_info[id(obj)].append(self.get_tensor_mem(obj)) + # recorded the order we got the tensor + # by this we can guess the tensor easily + # it will record every tensor updated this turn + self.order.append(id(obj)) + # recorded all different devices + if obj.device not in self.devices: + self.devices.append(obj.device) + + def print_tensors_state(self): + template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}' + self.info += LINE + self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem') + self.info += '\n' + self.info += LINE + + # if a tensor updates this turn, and was recorded before + # it should be updated in the saved_tensor_info as well + outdated = [x for x in self.saved_tensor_info.keys() if x in self.order] + minus = [x for x in self.saved_tensor_info.keys() if x not in self.detected] + minus = outdated + minus + if len(self.order) > 0: + for tensor_id in self.order: + self.info += template_format.format('+', str(self.tensor_info[tensor_id][0]), + str(self.tensor_info[tensor_id][1]), + str(tuple(self.tensor_info[tensor_id][2])), + str(self.tensor_info[tensor_id][3]), + str(self.tensor_info[tensor_id][4]), + str(self.tensor_info[tensor_id][5])) + self.info += '\n' + if len(self.order) > 0 and len(minus) > 0: + self.info += '\n' + if len(minus) > 0: + for tensor_id in minus: + self.info += template_format.format('-', str(self.saved_tensor_info[tensor_id][0]), + str(self.saved_tensor_info[tensor_id][1]), + str(tuple(self.saved_tensor_info[tensor_id][2])), + str(self.saved_tensor_info[tensor_id][3]), + str(self.saved_tensor_info[tensor_id][4]), + str(self.saved_tensor_info[tensor_id][5])) + self.info += '\n' + # deleted the updated tensor + self.saved_tensor_info.pop(tensor_id) + + # trace where is the detect() + locate_info = inspect.stack()[2] + locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno) + + self.info += LINE + self.info += f"Detect Location: {locate_msg}\n" + for device in self.devices: + if device == torch.device('cpu'): + continue + gpu_mem_alloc = self.mem_format(torch.cuda.memory_allocated(device)) + self.info += f"Totle GPU Memery Allocated on {device} is {gpu_mem_alloc}\n" + self.info += LINE + self.info += '\n\n' + if self.show_info: + print(self.info) + if self.log is not None: + with open(self.log + '.log', 'a') as f: + f.write(self.info) + + def detect(self, include_cpu=False): + self.include_cpu = include_cpu + self.collect_tensors_state() + self.print_tensors_state() + self.saved_tensor_info.update(self.tensor_info) + self.tensor_info.clear() + self.order = [] + self.detected = [] + self.info = "" + + def close(self): + self.saved_tensor_info.clear() + self.module = None diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..4b61f4a5ef1148b74962f349cd24d852da899737 --- /dev/null +++ b/colossalai/utils/timer.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import time +from typing import Tuple +from .cuda import synchronize + + +class Timer: + """A timer object which helps to log the execution times, and provides different tools to assess the times. + """ + + def __init__(self): + self._started = False + self._start_time = time.time() + self._elapsed = 0 + self._history = [] + + @property + def has_history(self): + return len(self._history) != 0 + + @property + def current_time(self) -> float: + synchronize() + return time.time() + + def start(self): + """Firstly synchronize cuda, reset the clock and then start the timer. + """ + self._elapsed = 0 + synchronize() + self._start_time = time.time() + self._started = True + + def lap(self): + """lap time and return elapsed time + """ + return self.current_time - self._start_time + + def stop(self, keep_in_history: bool = False): + """Stop the timer and record the start-stop time interval. + + Args: + keep_in_history (bool, optional): Whether does it record into history + each start-stop interval, defaults to False. + Returns: + int: Start-stop interval. + """ + synchronize() + end_time = time.time() + elapsed = end_time - self._start_time + if keep_in_history: + self._history.append(elapsed) + self._elapsed = elapsed + self._started = False + return elapsed + + def get_history_mean(self): + """Mean of all history start-stop time intervals. + + Returns: + int: Mean of time intervals + """ + return sum(self._history) / len(self._history) + + def get_history_sum(self): + """Add up all the start-stop time intervals. + + Returns: + int: Sum of time intervals. + """ + return sum(self._history) + + def get_elapsed_time(self): + """Return the last start-stop time interval. + + Returns: + int: The last time interval. + + Note: + Use it only when timer is not in progress + """ + assert not self._started, 'Timer is still in progress' + return self._elapsed + + def reset(self): + """Clear up the timer and its history + """ + self._history = [] + self._started = False + self._elapsed = 0 + + +class MultiTimer: + """An object contains multiple timers. + + Args: + on (bool, optional): Whether the timer is enabled. Default is True. + """ + + def __init__(self, on: bool = True): + self._on = on + self._timers = dict() + + def start(self, name: str): + """Start namely one of the timers. + + Args: + name (str): Timer's key. + """ + if self._on: + if name not in self._timers: + self._timers[name] = Timer() + return self._timers[name].start() + + def stop(self, name: str, keep_in_history: bool): + """Stop namely one of the timers. + + Args: + name (str): Timer's key. + keep_in_history (bool): Whether does it record into history each start-stop interval. + """ + if self._on: + return self._timers[name].stop(keep_in_history) + else: + return None + + def get_timer(self, name): + """Get timer by its name (from multitimer) + + Args: + name (str): Timer's key. + Returns: + :class:`colossalai.utils.Timer`: Timer with the name you give correctly. + """ + return self._timers[name] + + def reset(self, name=None): + """Reset timers. + + Args: + name (str, optional): If name is designated, the named timer will be reset + and others will not, defaults to None. + """ + if self._on: + if name is not None: + self._timers[name].reset() + else: + for timer in self._timers: + timer.reset() + + def is_on(self): + return self._on + + def set_status(self, mode: bool): + self._on = mode + + def __iter__(self) -> Tuple[str, Timer]: + for name, timer in self._timers.items(): + yield name, timer diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..098ccbb45c5a5a4714fbd3785a070f5d94631f17 --- /dev/null +++ b/colossalai/zero/__init__.py @@ -0,0 +1,41 @@ +from typing import Tuple + +import torch +import torch.nn as nn + +from colossalai.logging import get_dist_logger +from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 +from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2 + +from ..nn.optimizer.zero_optimizer import ZeroOptimizer + + +def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config, + optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]: + """ + A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading + + :param model: Your model object + :type model: :class:`torch.nn.Module` + :param optimizer_config: Your optimizer object + :type optimizer_config: :class:`dict` + + :return: (model, optimizer) + :rtype: Tuple + """ + + logger = get_dist_logger('convert_to_zero_v2') + + logger.info(f'optimizer_config is {optimizer_config}', ranks=[0]) + if optimizer_config is None: + optimizer_config = dict() + logger.info(f'model_config is {model_config}', ranks=[0]) + if model_config is None: + model_config = dict() + + zero_model = ShardedModelV2(model, **model_config) + zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config) + return zero_model, zero_optimizer + + +__all__ = ['convert_to_zero_v2', 'LowLevelZeroOptimizer', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer'] diff --git a/colossalai/zero/init_ctx/__init__.py b/colossalai/zero/init_ctx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a6f81566a9de2d83561fe7d91f9052244b286b8 --- /dev/null +++ b/colossalai/zero/init_ctx/__init__.py @@ -0,0 +1,3 @@ +from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator + +__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator'] diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py new file mode 100644 index 0000000000000000000000000000000000000000..572ddd9e4e3fdf9b6f2c18d69b3b57399abb320a --- /dev/null +++ b/colossalai/zero/init_ctx/init_context.py @@ -0,0 +1,266 @@ +import contextlib +import functools +from typing import Optional +from contextlib import AbstractContextManager + +import torch +import torch.nn as nn +import torch.distributed as dist + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.context.singleton_meta import SingletonMeta +from colossalai.logging import get_dist_logger +from colossalai.zero.shard_utils import BaseShardStrategy +from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 +from colossalai.zero.sharded_param import ShardedParamV2 +from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses + + +class ZeroContextConfig(object): + """The configuration used to control zero context initialization. + + Args: + target_device (torch.device): The device where param data are after exiting the context. + replicated (bool, optional): Whether the param is replicated across data parallel group. + Some parameters are not replicated, e.g. parameters in MOE experts. + shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False. + """ + + def __init__(self, target_device: torch.device, replicated: bool = True, shard_param: bool = False): + super().__init__() + + if shard_param: + assert replicated, "Non-replicated parameters can't be sharded." + + # replicated no-shard parameters should locate in cuda, since we will broadcast them soon + if replicated and not shard_param: + assert target_device.type == 'cuda', "Replicated no-shard paramters should locate in cuda." + + self.target_device = target_device + self.is_replicated: bool = replicated + self.shard_param: bool = shard_param + + +class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): + """A context to initialize model. + + 1. Convert the model to fp16. + 2. The paramaters of the module are adapted to type ShardedParameter. + 3. Shard the param and grad according to flags. + + Args: + target_device (torch.device): The device where param data are after exiting the context. + shard_strategy (BaseShardStrategy): Shard strategy instance. + seed (int, optional): Random seed for weight initialization + shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False. + default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16. + model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int). + """ + + def __init__(self, + target_device: torch.device, + shard_strategy: BaseShardStrategy, + seed: int = 2**10 - 1, + shard_param: bool = False, + default_dtype: Optional[torch.dtype] = None, + model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)): + + super().__init__(default_dtype=default_dtype) + self.shard_strategy = shard_strategy + self.param_list = [] + self.model_numel_tensor = model_numel_tensor + self.seed = seed + self.dp_process_group = gpc.get_group(ParallelMode.DATA) + + self.config = ZeroContextConfig(target_device=target_device, replicated=True, shard_param=shard_param) + + ZeroContextMgr().current_context = self + + self.param_numel = {} + self.top_module = None + + @property + def target_device(self): + return self.config.target_device + + @property + def is_replicated(self): + return self.config.is_replicated + + @property + def shard_param(self): + return self.config.shard_param + + @staticmethod + def calc_fanin_fanout(tensor: torch.Tensor): + """We use this function to substitute fan-in and fan-out calculation in torch.nn.init. + This can help us get correct fan-in and fan-out for sharded tensor. + """ + assert isinstance(tensor, nn.Parameter), "Sharded tensor initilization is only allowed for paramters" + + # get correct shape of input tensor + if not hasattr(tensor, 'colo_attr') or not tensor.colo_attr.param_is_sharded: + tensor_shape = tensor.shape + else: + tensor_shape = tensor.colo_attr.sharded_data_tensor.origin_shape + + dimensions = len(tensor_shape) + if dimensions < 2: + raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") + + num_input_fmaps = tensor_shape[1] + num_output_fmaps = tensor_shape[0] + receptive_field_size = 1 + if dimensions > 2: + # math.prod is not always available, accumulate the product manually + # we could use functools.reduce but that is not supported by TorchScript + for s in tensor_shape[2:]: + receptive_field_size *= s + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + def _pre_context_exec(self): + """ + The Callback function when entering the context + """ + self.logger = get_dist_logger("ZeroInitContext") + + # substitute fan-in and fan-out calculation + self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out + nn.init._calculate_fan_in_and_fan_out = self.calc_fanin_fanout + + self.module_load_from_state_dict = nn.Module._load_from_state_dict + shard_strategy = self.shard_strategy if self.config.shard_param else None + nn.Module._load_from_state_dict = functools.partialmethod(ShardedModelV2._colo_load_from_state_dict, + shard_strategy=shard_strategy) + self.module_state_dict = nn.Module.state_dict + nn.Module.state_dict = functools.partialmethod(ShardedModelV2._colo_state_dict, + shard_strategy=shard_strategy, + state_dict_func=self.module_state_dict, + process_group=self.dp_process_group) + + # reserve rng states + self.cpu_rng_state = torch.get_rng_state() + self.cuda_rng_state = torch.cuda.get_rng_state() + + # set new seed for initialization, since we initialize sharded tensor separately + # we don't want all processes have the same seed + # otherwise all sharded tensors are same after init + offset = self.seed + 1 # we want to have more 1 in binary format seed + torch.manual_seed(self.seed + offset * dist.get_rank()) + + def _post_context_exec(self): + """The callback function when exiting context. + """ + # broadcast replicated no-shard parameters + src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] + for param in self.param_list: + assert hasattr(param, 'colo_attr') + if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated: + dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group) + param.colo_attr.set_data_none() + + del self.param_list + + nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout + nn.Module.load_state_dict = self.module_load_from_state_dict + nn.Module.state_dict = self.module_state_dict + torch.set_rng_state(self.cpu_rng_state) + torch.cuda.set_rng_state(self.cuda_rng_state) + + params = frozenset(self.top_module.parameters()) + for param in self.param_numel.keys(): + if param not in params: + self.param_numel[param] = 0 + self.model_numel_tensor.fill_(sum(self.param_numel.values())) + + def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): + """ + The function to call at the end of the constructor of each module. + NOTE() The module may be passed to this function multiple times. + """ + self.top_module = module + + def half_fn(t: torch.Tensor): + return t.half() if t.is_floating_point() else t + + for param in module.parameters(recurse=False): + # avoid adapting a param to ShardedParam twice + if hasattr(param, 'colo_attr'): + continue + + self.param_numel[param] = param.numel() + + # convert parameters to half + param_half = half_fn(param) + param.data = param_half + if param.grad is not None: + grad_half = half_fn(param.grad) + param.grad.data = grad_half + + # move torch parameters to the target device + target_device = self.target_device + param.data = param.data.to(target_device) + if param.grad is not None: + param.grad = param.grad.to(target_device) + + param.colo_attr = ShardedParamV2(param, set_data_none=True) + + if self.shard_param: + self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group) + + param.data = param.colo_attr.data_payload # set param.data to payload + + # mark whether the param is replicated + param.colo_attr.is_replicated = self.is_replicated + + # mark whether the param should keep not sharded + # if True, the param is used as Zero stage 2 + param.colo_attr.keep_not_shard = not self.shard_param + + self.param_list.append(param) + + # We must cast buffers + # If we use BN, buffers may be on CPU and Float + # We must cast them + for buffer in module.buffers(recurse=False): + buffer.data = buffer.data.to(device=torch.cuda.current_device()) + buffer.data = cast_tensor_to_fp16(buffer.data) + + +class ZeroContextMgr(metaclass=SingletonMeta): + current_context: Optional[ZeroInitContext] = None + + @contextlib.contextmanager + def hijack_context_config(self, **kwargs): + if self.current_context is None: + yield + else: + old_config = self.current_context.config + self.current_context.config = ZeroContextConfig(**kwargs) + yield + self.current_context.config = old_config + + +def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager: + return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()), + replicated=is_replicated, + shard_param=False) + + +def no_shard_zero_decrator(is_replicated: bool = True): + + def _wrapper(init_func): + + def _no_shard(*args, **kwargs): + with no_shard_zero_context(is_replicated): + ret = init_func(*args, **kwargs) + return ret + + return _no_shard + + return _wrapper diff --git a/colossalai/zero/shard_utils/__init__.py b/colossalai/zero/shard_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5d63a7e768a470b609ccd185012864752cb432 --- /dev/null +++ b/colossalai/zero/shard_utils/__init__.py @@ -0,0 +1,5 @@ +from .base_shard_strategy import BaseShardStrategy +from .bucket_tensor_shard_strategy import BucketTensorShardStrategy +from .tensor_shard_strategy import TensorShardStrategy + +__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy'] diff --git a/colossalai/zero/shard_utils/base_shard_strategy.py b/colossalai/zero/shard_utils/base_shard_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..7c2f4c9f665993b4304d19296f41ba89f5af49d5 --- /dev/null +++ b/colossalai/zero/shard_utils/base_shard_strategy.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +import torch.distributed as dist +from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor + + +class BaseShardStrategy(ABC): + + def __init__(self) -> None: + """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs. + """ + super().__init__() + + @abstractmethod + def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): + pass + + @abstractmethod + def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): + pass diff --git a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..a7bd7cf538e7e809ffafc157ec9e1056b093f1d7 --- /dev/null +++ b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py @@ -0,0 +1,46 @@ +from typing import List, Optional + +import torch +import torch.distributed as dist +from colossalai.utils import get_current_device +from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor +from torch._utils import _flatten_dense_tensors as flatten + +from .tensor_shard_strategy import TensorShardStrategy + + +class BucketTensorShardStrategy(TensorShardStrategy): + """Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together, + which will fully utilize network bandwidth. + It is especially useful when sub-module contains bias, + since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usaully small). + """ + + def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): + + tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded] + if len(tensor_list) == 0: + return + target_device = tensor_list[0].device + dtype = tensor_list[0].dtype + buffer_list: List[torch.Tensor] = [] + tensor_numels = [t.payload.numel() for t in tensor_list] + buffer_size = sum(tensor_numels) + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + for i in range(world_size): + if i == rank: + buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device())) + else: + buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device())) + dist.all_gather(buffer_list, buffer_list[rank], group=process_group) + # Move to target device before splitting buffer + # Ensure we utilize maximum PCIE bandwidth + buffer_list = [buffer.to(target_device) for buffer in buffer_list] + offset = 0 + for i, t in enumerate(tensor_list): + gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list] + gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape) + t.payload_reset(gathered_payload) + t.is_sharded = False + offset += tensor_numels[i] diff --git a/colossalai/zero/shard_utils/commons.py b/colossalai/zero/shard_utils/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..71cef44c177f9c5b728a185d6a5c876196ff7f8c --- /dev/null +++ b/colossalai/zero/shard_utils/commons.py @@ -0,0 +1,22 @@ +import torch +import torch.nn.functional as F +from typing import Tuple + + +def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]: + """Return the local shard of a full tensor.""" + # Shard using torch.chunk to match all-gather/reduce-scatter. + chunks = list(torch.flatten(tensor).chunk(world_size)) + while len(chunks) < world_size: + chunks.append(chunks[0].new_empty(0)) + + # Determine number of padding elements. + num_to_pad = chunks[0].numel() - chunks[rank].numel() + assert num_to_pad >= 0, num_to_pad + + shard = torch.zeros_like(chunks[0]) + length = chunks[rank].size(0) + shard_temp = shard[:length] + shard_temp.copy_(chunks[rank]) + + return shard, num_to_pad diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..5bdd95400d82ef0b890193e762fb7c46b8906787 --- /dev/null +++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py @@ -0,0 +1,58 @@ +from typing import List, Optional + +import torch +import torch.distributed as dist +from colossalai.utils import get_current_device +from colossalai.zero.shard_utils import BaseShardStrategy +from colossalai.zero.shard_utils.commons import get_shard +from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor +from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline + + +class TensorShardStrategy(BaseShardStrategy): + """ + A naive implementation which shard each tensor evenly over all ranks + """ + + def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): + for t in tensor_list: + self._shard_tensor(t, process_group) + + def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): + for t in tensor_list: + self._gather_tensor(t, process_group) + + def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None): + """ Shard tensor among processes. + + Args: + t (ShardedTensor): a tensor to be sharded. + process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards. + Defaults to None. + """ + if t.is_sharded: + return + if t.payload.device.type == 'cuda': + assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\ + f" but current cuda device is {get_current_device()}" + sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group)) + t.payload_reset(sharded_payload) + t.is_sharded = True + + def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None): + if not t.is_sharded: + return + target_device = t.device + payload_numel = t.payload.numel() + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + + buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device()) + buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0)) + buffer_list[rank].copy_(t.payload) + + dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False) + gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape) + t.payload_reset(gathered_payload) + colo_model_data_tensor_move_inline(t, target_device) + t.is_sharded = False diff --git a/colossalai/zero/sharded_model/__init__.py b/colossalai/zero/sharded_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..725179295c60ae22211a475a33dc9c7885357e6e --- /dev/null +++ b/colossalai/zero/sharded_model/__init__.py @@ -0,0 +1,3 @@ +from .sharded_model_v2 import ShardedModelV2 + +__all__ = ['ShardedModelV2'] \ No newline at end of file diff --git a/colossalai/zero/sharded_model/_utils.py b/colossalai/zero/sharded_model/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..85a3ab73dd1b3743a462929c94432ed3491e9c52 --- /dev/null +++ b/colossalai/zero/sharded_model/_utils.py @@ -0,0 +1,77 @@ +from typing import Any, Callable, List, Tuple + +import torch +import torch.nn.functional as F +from typing import Union +from colossalai.gemini.stateful_tensor import StatefulTensor + + +def get_gradient_predivide_factor(world_size: int) -> float: + factor: int = 1 + while world_size % factor == 0 and world_size / factor > factor: + factor *= 2 + return float(factor) + + +def free_storage(data: torch.Tensor) -> None: + """Free underlying storage of a Tensor.""" + if data.storage().size() > 0: + # Since we're modifying the Tensor's Storage directly, make sure the Tensor + # is the sole occupant of the Storage. + assert data.storage_offset() == 0 + data.storage().resize_(0) + + +@torch.no_grad() +def alloc_storage(data: torch.Tensor, size: torch.Size) -> None: + """Allocate storage for a tensor.""" + if data.storage().size() == size.numel(): # no need to reallocate + return + assert data.storage().size() == 0 + data.storage().resize_(size.numel()) + + +def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor: + if isinstance(tensor, StatefulTensor): + tensor = tensor.payload + if torch.is_floating_point(tensor) and tensor.dtype is torch.float32: + return tensor.half() + return tensor + + +def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Tensor: + if isinstance(tensor, StatefulTensor): + tensor = tensor.payload + + if torch.is_floating_point(tensor) and tensor.dtype is torch.float16: + return tensor.float() + return tensor + + +def apply_to_tensors(x: Any, fn: Callable): + if torch.is_tensor(x): + return fn(x) + elif isinstance(x, list): + return [apply_to_tensors(t, fn) for t in x] + elif isinstance(x, tuple): + return tuple(apply_to_tensors(t, fn) for t in x) + elif isinstance(x, dict): + return {key: apply_to_tensors(val, fn) for key, val in x.items()} + else: + return x + + +def cast_float_arguments(fn: Callable, *args: Any, **kwargs: Any) -> Tuple[Any, Any]: + return apply_to_tensors(args, fn), apply_to_tensors(kwargs, fn) + + +def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]: + """Chunk a given Tensor into num_chunks parts and add any necessary padding.""" + chunks = list(torch.flatten(tensor).chunk(num_chunks)) + # torch.chunk may return fewer than num_chunks chunks, pad accordingly. + num_pad_for_partial_chunk = chunks[0].numel() - chunks[-1].numel() + if num_pad_for_partial_chunk > 0: + chunks[-1] = F.pad(chunks[-1], [0, num_pad_for_partial_chunk]) + if len(chunks) < num_chunks: + chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))]) + return chunks diff --git a/colossalai/zero/sharded_model/reduce_scatter.py b/colossalai/zero/sharded_model/reduce_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..4fb507382df9eae2d3efa35fdcdcb2704a9256dc --- /dev/null +++ b/colossalai/zero/sharded_model/reduce_scatter.py @@ -0,0 +1,200 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import os +from typing import Callable, Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup + +# TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved. +if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0": + enable_nccl_base_collectives = False +else: + enable_nccl_base_collectives = True + + +class Bucket: + + def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup): + self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device) + self.group = group + self.offset = 0 + self.callbacks: List[Callable] = [] + self.output_shard = torch.zeros_like(self.buffer[0]) + + def flush(self) -> None: + """Flush content of the bucket.""" + if self.offset == 0: + assert len(self.callbacks) == 0 + return + # reduce-scatter bucket + if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives: + dist._reduce_scatter_base(self.output_shard[:self.offset], + self.buffer[:, :self.offset].contiguous(), + group=self.group) + else: + dist.reduce_scatter(self.output_shard[:self.offset], + list(self.buffer[:, :self.offset].unbind(0)), + group=self.group) + # execute post-reduction callbacks + for callback_fn in self.callbacks: + callback_fn() + # reuse input bucket but allocate a fresh output shard + self.buffer[:, :self.offset].zero_() + self.offset = 0 + self.callbacks.clear() + self.output_shard = torch.zeros_like(self.buffer[0]) + + def alloc(self) -> None: + """Setup the buffers if they are not allocated. + + Using ``setup`` and ``teardown``, we can ensure that the bucket + buffers are only allocated during the backward pass, hence saving more + memory to other parts of the training process, such as the forward pass + for activation memory. + """ + for tensor in [self.buffer, self.output_shard]: + if tensor.storage().size() == 0: + tensor.storage().resize_(tensor.size().numel()) + + def free(self) -> None: + """Tear down the bucket by freeing the memory""" + assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown" + for tensor in [self.buffer, self.output_shard]: + tensor.storage().resize_(0) + + def append(self, tensor_list: List[Tensor], callback_fn: Callable): + # copy data from input_list into bucket + tensor_size = tensor_list[0].numel() + stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size) + offset = self.offset + self.buffer[:, offset:offset + tensor_size].copy_(stacked_input) + self.offset += tensor_size + + # callback will be given the reduced result + if callback_fn is not None: + result_view = self.output_shard[offset:offset + tensor_size].view_as(tensor_list[0]) + self.callbacks.append(functools.partial(callback_fn, result_view)) + + +class ReduceScatterBucketer: + """ + Helper for bucketing multiple reduce-scatter operations on small tensors + into larger reduce-scatter ops to improve communication efficiency. + + Usage:: + + bucketer = ReduceScatterBucketer() + bucketer.reduce_scatter_async( + small_tensors, callback_fn=lambda result: print("small") + ) + bucketer.reduce_scatter_async( + big_tensors, callback_fn=lambda result: print("big") + ) + bucketer.reduce_scatter_async( + more_small_tensors, callback_fn=lambda result: print("small2") + ) + bucketer.flush() # callbacks only guaranteed to be called after flush() + # Example output (note that it is out of order, due to bucketing): + # big + # small + # small2 + + Args: + bucket_size_mb (int, Optional): bucket size for communicating. Buckets + are sub-divided based on world_size. Values <= 0 disable bucketing. + """ + + def __init__(self, bucket_size_mb: int = 25): + self.bucket_size_mb = bucket_size_mb + self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {} + + @torch.no_grad() + def reduce_scatter_async( + self, + input_list: List[Tensor], + group: ProcessGroup, + callback_fn: Optional[Callable] = None, + ) -> None: + """ + Reduce-scatter a list of tensors asynchronously, so smaller reductions + can be bucketed together. The given callback (``callback_fn``) will be + called with the reduced result at some later time. Call ``flush()`` to + force all queued ops and callbacks to be executed. + + Note that large inputs will be reduced immediately, and this function + may also flush the relevant bucket to make room for ``input_list``. + + Args: + input_list (List[Tensor]): list of tensors to reduce-scatter. List + should contain ``group.size()`` tensors and each tensor should + have identical shape, dtype and device. + group (ProcessGroup): process group for reduction + callback_fn (Callable, Optional): callback function to call after + the reduction executes. Function will be called with a single + argument corresponding to the reduced result. + """ + world_size = group.size() + + assert (len(input_list) == world_size + ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})" + + first_input = input_list[0] + first_input_size = first_input.numel() + + bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size) + if first_input_size > bucket_shard_size: + # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors) + # input is too big to fit in the bucket, reduce-scatter directly + output = torch.zeros_like(input_list[0]) + if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives: + input_flattened = torch.cat(input_list) + dist._reduce_scatter_base(output, input_flattened, group=group) + else: + # fallback + dist.reduce_scatter(output, input_list, group=group) + if callback_fn is not None: + callback_fn(output) + return + + bucket = self._get_bucket(first_input, group) + if first_input_size > bucket.buffer.size(1) - bucket.offset: + # not enough space remaining in bucket, flush it now + bucket.flush() + bucket.append(input_list, callback_fn) + + @torch.no_grad() + def flush(self) -> None: + """Reduce-scatter any partial buckets.""" + for bucket in self.buckets.values(): + bucket.flush() + + @torch.no_grad() + def free(self) -> None: + """Free buffers from all buckets.""" + for bucket in self.buckets.values(): + bucket.free() + + @functools.lru_cache() + def _get_shard_size(self, element_size: int, num_shards: int) -> int: + if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. + return 0 + MB = 1024 * 1024 + bucket_size = self.bucket_size_mb * MB / element_size + return int(bucket_size // num_shards) + + def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket: + key = (tensor.dtype, tensor.device, group) + if key not in self.buckets: + # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size) + world_size = group.size() + shard_size = self._get_shard_size(tensor.element_size(), world_size) + self.buckets[key] = Bucket(shard_size, tensor.dtype, tensor.device, group) + self.buckets[key].alloc() + return self.buckets[key] diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..ae3a619980accff2adb489cf5f4b719b214fe213 --- /dev/null +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -0,0 +1,570 @@ +import functools +import itertools +from collections import OrderedDict +from copy import deepcopy +from typing import Any, Iterator, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector +from colossalai.gemini.ophooks import register_ophooks_recursively +from colossalai.gemini.paramhooks import BaseParamHookMgr +from colossalai.gemini.stateful_tensor import TensorState +from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory +from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu +from colossalai.logging import get_dist_logger +from colossalai.utils import disposable, get_current_device +from colossalai.utils.memory import colo_device_memory_capacity +from colossalai.zero.shard_utils import BaseShardStrategy +from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer +from colossalai.zero.utils import ZeroHook + +from ._utils import ( + cast_float_arguments, + cast_tensor_to_fp16, + cast_tensor_to_fp32, + chunk_and_pad, + free_storage, + get_gradient_predivide_factor, +) + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + + +class ShardedModelV2(nn.Module): + """ + A wrapper for the PyTorch module shards the model parameters among multiple GPU memory. + Only `1/#nproc` of parameters, gradients are stored in local CUDA memory, so forward and backward + passes can be executed with limited CUDA memory budget. + + Note: + You must use ``ShardedModelV2`` with ``ShardedOptimizerV2``. + Note: + Make sure you don't use gradient accumulation and your optimizer can work with fp16 gradient and fp32 parameter, + if you enable ``reuse_fp16_shard``. + + Args: + module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`. + shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior. + process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None. + reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group. + Generally, it should be `None`, and it's the same as `process_group`. Defaults to None. + reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25. + fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False. + tensor_placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'. + If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used. + If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used. + If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well. + Note that 'auto' policy can only work well when no other processes use CUDA during your training. + Defaults to 'cuda'. + gradient_predivide_factor (Optional[float], optional): Gradient is divived by this value before reduce-scatter. Defaults to 1.0. + reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad. + Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation. + In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad). + We find that PyTorch's optimizers don't support mixed precision, + so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False. + """ + + def __init__(self, + module: nn.Module, + shard_strategy: BaseShardStrategy, + process_group: Optional[ProcessGroup] = None, + reduce_scatter_process_group: Optional[ProcessGroup] = None, + reduce_scatter_bucket_size_mb: int = 25, + fp32_reduce_scatter: bool = False, + tensor_placement_policy: str = 'cuda', + gradient_predivide_factor: Optional[float] = 1.0, + reuse_fp16_shard: bool = False, + *args, + **kwargs): + assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.' + super().__init__() + self.logger = get_dist_logger() + + # We force users to use ZeroInitContext + for submodule in module.modules(): + sharded_cnt = 0 + unshard_cnt = 0 + for param in submodule.parameters(recurse=False): + assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.' + if param.colo_attr.param_is_sharded: + sharded_cnt += 1 + else: + unshard_cnt += 1 + assert (not sharded_cnt) or (not unshard_cnt), 'nn.Module can not both have shard param and unshard param' + submodule.param_is_sharded = (sharded_cnt > 0) + + self.sharded_params = [] + self.unshard_params = [] + for param in module.parameters(): + if param.colo_attr.param_is_sharded: + self.sharded_params.append(param) + else: + self.unshard_params.append(param) + + self.module = module + self.process_group = process_group or gpc.get_group(ParallelMode.DATA) + self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group + self.world_size = dist.get_world_size(self.process_group) + self.rank = dist.get_rank(self.process_group) + self.shard_strategy = shard_strategy + + self._use_memory_tracer = tensor_placement_policy == 'auto' + if self._use_memory_tracer: + self._memstats_collector = MemStatsCollector() + self._start_collect_memstats = disposable(self._memstats_collector.start_collection) + self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection) + else: + self._memstats_collector = None + self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create( + tensor_placement_policy)(mem_stats_collector=self._memstats_collector) + + if 'warmup_non_model_data_ratio' in kwargs: + if tensor_placement_policy != 'auto': + self.logger.warning('setting warmup_non_model_data_ratio is useless if not use auto placement') + else: + ratio = kwargs['warmup_non_model_data_ratio'] + self._tensor_placement_policy._warmup_non_model_data_ratio = ratio + self.logger.info(f'setting warmup_non_model_data_ratio as {ratio} for auto placement') + + self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy) + param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')] + self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list) + + # Register hooks + self._ophook_list = [ + ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group) + ] + register_ophooks_recursively(self.module, self._ophook_list) + self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters())) + self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook) + + self.fp32_reduce_scatter = fp32_reduce_scatter + self._cpu_offload: bool = tensor_placement_policy != 'cuda' + for param in module.parameters(): + # Init `offload_grad` + param.colo_attr.offload_grad = self._cpu_offload + + # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem + # So we use 1.0 as the default gradient_predivide_factor + # However, if you set gradient_predivide_factor to None, we will set + # gradient_predivide_factor to a value >= 1.0 automatically + self.gradient_predivide_factor: float = gradient_predivide_factor if \ + gradient_predivide_factor is not None else \ + get_gradient_predivide_factor(self.world_size) + self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor + + self.comm_stream: torch.cuda.Stream = torch.cuda.Stream() + self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb) + self._require_backward_grad_sync: bool = True + + self._cuda_margin_space = 0 + self.reuse_fp16_shard = reuse_fp16_shard + + # record whether gradients have inf or nan + self.overflow_counter = 0 + + def adjust_stateful_tensor_layout(self) -> None: + self._stateful_tensor_mgr.adjust_layout() + + @property + def use_memory_tracer(self): + return self._use_memory_tracer + + @property + def cuda_margin_space(self): + return self._cuda_margin_space + + @property + def cpu_offload(self): + return self._cpu_offload + + def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None: + """ + dummy memory tracer collected infomation to a file. + try: + # forward: model(inputs) + # backward: optimizer.backward() + except Exception as e: + model.dump_memory_stats() + exit(0) + """ + if self._use_memory_tracer: + self.logger.error(f'dump memort tracer collected infomation to a {filename}', ranks=[0]) + if gpc.get_global_rank() == 0: + with open(filename, 'w+') as f: + f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n') + f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n') + f.write('CUDA model data (GB)\n') + f.write('\n') + f.write('CUDA non model data (GB)\n') + f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda'))) + f.write('CPU non model data (GB)\n') + f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu'))) + f.write('\n') + + def _pre_forward_operations(self, *args): + # the operation will affect the memory tracer behavior in ZeroHook + if self._memstats_collector: + self._start_collect_memstats() + + for p in self.module.parameters(): + if hasattr(p, 'colo_attr'): + p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD) + + self._stateful_tensor_mgr.start_iter() + + def _post_forward_operations(self): + for p in self.module.parameters(): + if hasattr(p, 'colo_attr'): + p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD) + + def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: + self._pre_forward_operations(*args) + args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs) + outputs = self.module(*args, **kwargs) + self._post_forward_operations() + return outputs + + def backward(self, loss): + loss.backward() + self._post_backward_operations() + for ophook in self._ophook_list: + ophook.post_iter() + + def backward_by_grad(self, tensor, grad): + torch.autograd.backward(tensors=tensor, grad_tensors=grad) + self._post_backward_operations() + for ophook in self._ophook_list: + ophook.post_iter() + + def _update_memstats(self): + if self._memstats_collector: + self._finish_collect_memstats() + # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used. + # the way to calculate margin space is based on the assumption that + # model data is fixed in cuda during training. + # cuda margin space can be used to store OS. + self._cuda_margin_space = colo_device_memory_capacity( + get_current_device()) - self._memstats_collector._memstats.max_overall_cuda + + @torch.no_grad() + def _post_backward_operations(self) -> None: + """ + The method includes operations required to be processed after backward + 1. update memory tracer. + 2. flush the gradient in buckets. Reducing partial gradients in each process. + 3. shard tensors not dealed in the zero hook + 4. move sharded param grad payload to param.grad + """ + # 1. update memory tracer. + self._update_memstats() + + # 2. flush the gradient in buckets. Reducing partial gradients in each process. + if self._require_backward_grad_sync: + # Flush any unreduced buckets in the post_backward stream. + with torch.cuda.stream(self.comm_stream): + self.reducer.flush() + torch.cuda.current_stream().wait_stream(self.comm_stream) + self.reducer.free() + + # 3. shard tensors not dealed in the zero hook + tensor_list = [] + for p in self.sharded_params: + if not p.colo_attr.param_is_sharded: + tensor_list.append(p.colo_attr.sharded_data_tensor) + p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD) + p.colo_attr.set_data_none() + self.shard_strategy.shard(tensor_list, self.process_group) + + # 4. set all parameters' grad to None + for p in self.module.parameters(): + if not p.requires_grad: + continue + # Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass. + # NOTE() (no-sync)/sync pass: (not conduct)/conduct gradient allreducing between process group. + # If _require_backward_grad_sync is True, + # p.grad remains the accumulated unsharded gradient from prior no-sync passes. + # We also allows to interleave no-sync pass with sync passes, if desired. + if not self._require_backward_grad_sync: + continue + + p.grad = None + + @torch.no_grad() + def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]: + """ + At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the + full gradient for the local batch. The reduce-scatter op will save + a single shard of the summed gradient across all + GPUs to param.colo_attr.grad. This shard will align with the current GPU rank. For example:: + + before reduce_scatter: + param.grad (GPU #0): [1, 2, 3, 4] + param.grad (GPU #1): [5, 6, 7, 8] + + after reduce_scatter: + param.grad (GPU #0): [6, 8] # 1+5, 2+6 + param.grad (GPU #1): [10, 12] # 3+7, 4+8 + + The local GPU's ``optim.step`` is responsible for updating a single + shard of params, also corresponding to the current GPU's rank. This + alignment is created by `param.colo_attr.grad`, which ensures that + the local optimizer only sees the relevant parameter shard. + """ + if grad is None: + return + assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients' + if not self._require_backward_grad_sync: + return + # used to cheat Pytorch, since we can't return None + empty_grad = torch.empty_like(grad) + free_storage(empty_grad) + # As torch didn't allow modifying grad in hook, we make a copy + grad = grad.clone() + if param.colo_attr.is_replicated: + self._reduce_scatter_handler(param, grad) + else: + self._save_grad(param, grad) + return empty_grad + + def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None: + self.comm_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.comm_stream): + if self.fp32_reduce_scatter: + grad.data = grad.data.to(param.dtype) + if self.gradient_predivide_factor > 1.0: + # Average grad by world_size for consistency with PyTorch DDP. + grad.data.div_(self.gradient_predivide_factor) + if self.world_size > 1: + grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size()) + self.reducer.reduce_scatter_async(grad_chunks, + group=self.reduce_scatter_process_group, + callback_fn=functools.partial(self._reduce_scatter_callback, param)) + else: + self._reduce_scatter_callback(param, grad) + torch.cuda.current_stream().wait_stream(self.comm_stream) + + def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None: + assert isinstance(reduced_grad, + torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}" + reduced_grad.data = reduced_grad.data.contiguous().view(-1) + if self.gradient_postdivide_factor > 1: + # Average grad by world_size for consistency with PyTorch DDP. + reduced_grad.data.div_(self.gradient_postdivide_factor) + self._save_grad(param, reduced_grad) + + # FIXME(ver217): refactor the below line when impl eviction policy + def _save_grad(self, param: Parameter, grad: torch.Tensor): + + # record whether we have overflow + self.overflow_counter += torch.isinf(grad).any().item() + self.overflow_counter += torch.isnan(grad).any().item() + + # move gradient to cpu + if param.colo_attr.offload_grad: + colo_model_data_move_to_cpu(grad) + + if self.reuse_fp16_shard: + # make parameters point to gradient + + assert param.colo_attr.saved_grad.is_null( + ), 'Gradien accumulation is not supported when reuse_fp16_shard=True' + + param.colo_attr.grad_payload_reset(grad.data) + # release the memory of param + # we set a false None for parameter's payload + # so we can get paramter's device and dtype later in optimizer + param.colo_attr.data_payload_reset(torch.empty(0, device=grad.device, dtype=grad.dtype)) + + if param.colo_attr.is_replicated: + param.colo_attr.sharded_data_tensor.is_sharded = True + else: + + fp32_grad = cast_tensor_to_fp32(grad) + + if param.colo_attr.saved_grad.is_null(): + param.colo_attr.grad_payload_reset(fp32_grad) + else: + param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload)) + + # keep saved_grad in HOLD state + param.colo_attr.saved_grad.trans_state(TensorState.HOLD) + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + return self.module.parameters(recurse=recurse) + + def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: + return self.module.named_parameters(prefix, recurse) + + def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]': + return self._colo_state_dict(destination, + prefix, + keep_vars, + shard_strategy=self.shard_strategy, + state_dict_func=nn.Module.state_dict, + module_to_load=self.module, + sharded_params=self.sharded_params, + process_group=self.process_group) + + def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True) -> None: + for name, p in self.named_parameters(): + if name in state_dict: + p.colo_attr.data_payload_reset(state_dict[name].to(dtype=p.colo_attr.data_payload.dtype, + device=p.colo_attr.data_payload.device)) + # Force re-shard + p.colo_attr.sharded_data_tensor.is_sharded = False + self.shard_strategy.shard([p.colo_attr.sharded_data_tensor]) + elif strict: + raise RuntimeError(f'Missing key in state_dict: {name}') + + def _colo_state_dict(self, + destination=None, + prefix='', + keep_vars=False, + shard_strategy: Optional[BaseShardStrategy] = None, + state_dict_func=None, + module_to_load=None, + sharded_params=[], + process_group=None) -> 'OrderedDict[str, torch.Tensor]': + if len(sharded_params) == 0: + for param in self.parameters(): + if param.colo_attr.param_is_sharded: + sharded_params.append(param) + if shard_strategy is not None: + shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group) + for p in sharded_params: + p.data = p.colo_attr.data_payload + module_to_load = module_to_load or self + gathered_state_dict = state_dict_func(module_to_load, destination, prefix, keep_vars) + gathered_state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in gathered_state_dict.items()} + if shard_strategy is not None: + shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group) + for p in sharded_params: + p.colo_attr.set_data_none() + return gathered_state_dict + + def _colo_load_from_state_dict(self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + shard_strategy=None): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + for name, param in local_state.items(): + key = prefix + name + if key in state_dict: + input_param = state_dict[key] + if hasattr(param, 'colo_attr'): + param.colo_attr.data_payload_reset( + input_param.to(dtype=param.colo_attr.data_payload.dtype, + device=param.colo_attr.data_payload.device)) + if shard_strategy is not None: + # Force re-shard + param.colo_attr.sharded_data_tensor.is_sharded = False + shard_strategy.shard([param.colo_attr.sharded_data_tensor]) + else: + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' + 'the shape in current model is {}.'.format( + key, input_param.shape, param.shape)) + continue + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append('While copying the parameter named "{}", ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}, ' + 'an exception occurred : {}.'.format(key, param.size(), input_param.size(), + ex.args)) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", nn.Module.set_extra_state) is not nn.Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix):] + input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child + if input_name not in self._modules and input_name not in local_state: + unexpected_keys.append(key) + + def __getitem__(self, idx: int): + assert isinstance(self.module, nn.ModuleList) + return self.module[idx] + + def __len__(self): + assert isinstance(self.module, nn.ModuleList) + return len(self.module) + + def __iter__(self): + assert isinstance(self.module, nn.ModuleList) + return iter(self.module) diff --git a/colossalai/zero/sharded_model/utils.py b/colossalai/zero/sharded_model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..69f5a23ac92057f76d8c49bd2e367e520470dbc8 --- /dev/null +++ b/colossalai/zero/sharded_model/utils.py @@ -0,0 +1,19 @@ +import torch +from colossalai.zero.sharded_model import ShardedModelV2 + +import copy + + +def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module): + """ + copy param of the ShardedModelV2 to other_model. + Note the other_model has to be the same as self. + """ + for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()): + assert hasattr(zero_param, 'colo_attr') + shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded + if shard_flag: + sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor]) + param.data = copy.deepcopy(zero_param.colo_attr.data_payload) + if shard_flag: + sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor]) diff --git a/colossalai/zero/sharded_optim/__init__.py b/colossalai/zero/sharded_optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..30c26fb75f30d41026d11f97a3afdf4123154c98 --- /dev/null +++ b/colossalai/zero/sharded_optim/__init__.py @@ -0,0 +1,4 @@ +from .low_level_optim import LowLevelZeroOptimizer +from .sharded_optim_v2 import ShardedOptimizerV2 + +__all__ = ['ShardedOptimizerV2', 'LowLevelZeroOptimizer'] diff --git a/colossalai/zero/sharded_optim/_utils.py b/colossalai/zero/sharded_optim/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a839a5705c35762e6f9f6cd36335fc7b2809453 --- /dev/null +++ b/colossalai/zero/sharded_optim/_utils.py @@ -0,0 +1,261 @@ +import math + +import torch +import torch.distributed as dist +from torch._six import inf +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import is_model_parallel_parameter + + +def flatten(input_): + return _flatten_dense_tensors(input_) + + +def unflatten(flat, tensors): + return _unflatten_dense_tensors(flat, tensors) + + +def count_numel(tensor_list): + res = 0 + for tensor in tensor_list: + res += tensor.numel() + return res + + +def calculate_padding(numel, unit_size): + remainder = numel % unit_size + return unit_size - remainder if remainder else remainder + + +def shuffle_by_round_robin(tensor_list, num_partitions): + partitions = dict() + + for tensor_idx, tensor in enumerate(tensor_list): + partition_to_go = tensor_idx % num_partitions + if partition_to_go not in partitions: + partitions[partition_to_go] = [] + partitions[partition_to_go].append(dict(tensor=tensor, index=tensor_idx)) + + partitions_count = len(partitions) + new_tensor_list = [] + tensor_index_mapping = dict() + + for partition_id in range(partitions_count): + partition_tensors = partitions[partition_id] + for item in partition_tensors: + tensor_index_mapping[item['index']] = len(new_tensor_list) + new_tensor_list.append(item['tensor']) + + return new_tensor_list, tensor_index_mapping + + +# create a flat tensor aligned at the alignment boundary +def flatten_dense_tensors_with_padding(tensor_list, unit_size): + num_elements = count_numel(tensor_list) + padding = calculate_padding(num_elements, unit_size=unit_size) + + if padding > 0: + pad_tensor = torch.zeros(padding, device=tensor_list[0].device, dtype=tensor_list[0].dtype) + padded_tensor_list = tensor_list + [pad_tensor] + else: + padded_tensor_list = tensor_list + + return flatten(padded_tensor_list) + + +def is_nccl_aligned(tensor): + return tensor.data_ptr() % 4 == 0 + + +def get_grad_accumulate_object(tensor): + """ + Return the AccumulateGrad of the input tensor + """ + + # grad_fn reference: + # https://discuss.pytorch.org/t/in-the-grad-fn-i-find-a-next-functions-but-i-dont-understand-the-meaning-of-the-attribute/24463 + # expand_as reference: https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html#torch.Tensor.expand + # + # `next_functions` will return the backward graph where + # the first element is the AccumulateGrad of the leaf nodes. + # we want to get the AccumulateGrad of the input tensor instead of the leaf + # node in the whole computation graph. + # Therefore, we call expand_as to create a dummy graph + # where tensor_tmp and tensor indeed point to the same object. + # You can check this by print(tensor.data_ptr() == tensor_tmp.data_ptr()) + tensor_tmp = tensor.expand_as(tensor) + grad_acc_obj = tensor_tmp.grad_fn.next_functions[0][0] + return grad_acc_obj + + +def split_half_float_double(tensor_list): + dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"] + buckets = [] + for i, dtype in enumerate(dtypes): + bucket = [t for t in tensor_list if t.type() == dtype] + if bucket: + buckets.append(bucket) + return buckets + + +def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA): + """ + Reduce the tensor in the data parallel process group + + :param tensor: A tensor object to reduce/all-reduce + :param dtype: The data type used in communication + :param dst_rank: The source rank for reduce. If dst_rank is None, + :param parallel_mode: Communication parallel mode + all-reduce will be used instead of reduce. Default is None. + + :type tensor: torch.Tensor + :type dtype: torch.dtype, optional + :type dst_rank: int, optional + :type parallel_mode: ParallelMode, optional + """ + # use the original dtype + if dtype is None: + dtype = tensor.dtype + + # cast the data to specified dtype for reduce/all-reduce + if tensor.dtype != dtype: + tensor_to_reduce = tensor.to(dtype) + else: + tensor_to_reduce = tensor + + world_size = gpc.get_world_size(parallel_mode) + group = gpc.get_group(parallel_mode) + tensor_to_reduce.div_(world_size) + + # if rank is None, all reduce will be used + # else, reduce is used + use_all_reduce = dst_rank is None + + if use_all_reduce: + dist.all_reduce(tensor_to_reduce, group=group) + else: + ranks_in_group = gpc.get_ranks_in_group(parallel_mode) + global_rank = ranks_in_group[dst_rank] + dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group) + + # recover the original dtype + if tensor.dtype != dtype and tensor is not tensor_to_reduce: + local_rank = gpc.get_local_rank(parallel_mode) + if use_all_reduce or dst_rank == local_rank: + tensor.copy_(tensor_to_reduce) + + return tensor + + +def has_inf_or_nan(tensor): + try: + # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as tensor + # (which is true for some recent version of pytorch). + tensor_sum = float(tensor.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # tensor_sum = float(tensor.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum: + return True + return False + + +def release_param_grad(tensor_list): + for tensor in tensor_list: + tensor.grad = None + + +def calculate_global_norm_from_list(norm_list): + """ Compute total from a list of norms + """ + total_norm = 0.0 + for norm in norm_list: + total_norm += norm**2.0 + return math.sqrt(total_norm) + + +def compute_norm(gradients, params, dp_group, mp_group, norm_type=2): + """Clips gradient norm of an iterable of parameters. + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Note that + the gradients are modified in place. + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + Returns: + Total norm of the parameters (viewed as a single vector). + """ + + if mp_group is None: + mp_rank = 0 + else: + mp_rank = dist.get_rank(mp_group) + + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(g.data.abs().max() for g in gradients) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group) + + # Take max across all GPUs. + if mp_group is not None: + dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX) + total_norm = total_norm_cuda[0].item() + else: + total_norm = 0.0 + # if dist.get_rank() == 0: + # logger.info(f"Total Norm beginning {total_norm}") + + for g, p in zip(gradients, params): + # Pipeline parallelism may replicate parameters. Avoid multi-counting. + if is_model_parallel_parameter(p) or mp_rank == 0: + param_norm = g.data.double().norm(2) + total_norm += param_norm.item()**2 + + # Sum across all model parallel GPUs. + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group) + + if mp_group is not None: + dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM) + + total_norm = total_norm_cuda[0].item()**(1. / norm_type) + + if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: + total_norm = -1 + + return total_norm + + +def sync_param(flat_tensor, tensor_list): + """ + Synchronize the flattened tensor and unflattened tensor list. When + a list of tensor are flattened with `torch._utils._unflatten_dense_tensors`, + a new tensor is created. Thus, the flat tensor and original tensor list do not + share the same memory space. This function will update the tensor list so that + they point to the same value. + + :param flat_tensor: A flat tensor obtained by calling `torch._utils._unflatten_dense_tensors` on a tensor lsit + :param tensor_list: A list of tensors corresponding to the flattened tensor + :type flat_tensor: torch.Tensor + :type tensor_list: List[torch.Tensor] + """ + updated_params = unflatten(flat_tensor, tensor_list) + + # update the tensor data + for p, q in zip(tensor_list, updated_params): + p.data = q.data diff --git a/colossalai/zero/sharded_optim/bookkeeping/__init__.py b/colossalai/zero/sharded_optim/bookkeeping/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7bcacfabfded39972babff0536cb75b0c2c65506 --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/__init__.py @@ -0,0 +1,6 @@ +from .bucket_store import BucketStore +from .gradient_store import GradientStore +from .parameter_store import ParameterStore +from .tensor_bucket import TensorBucket + +__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket'] diff --git a/colossalai/zero/sharded_optim/bookkeeping/base_store.py b/colossalai/zero/sharded_optim/bookkeeping/base_store.py new file mode 100644 index 0000000000000000000000000000000000000000..d4436acaa4bf3269e9116e6cd501a4880aa868be --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/base_store.py @@ -0,0 +1,17 @@ +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + + +class BaseStore: + + def __init__(self, dp_parallel_mode=ParallelMode.DATA): + self._world_size = gpc.get_world_size(dp_parallel_mode) + self._local_rank = gpc.get_local_rank(dp_parallel_mode) + + @property + def world_size(self): + return self._world_size + + @property + def local_rank(self): + return self._local_rank diff --git a/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py b/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py new file mode 100644 index 0000000000000000000000000000000000000000..0f2b1bb88b582e53bc8b0df21d9650fd89275d63 --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py @@ -0,0 +1,44 @@ +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + +from .base_store import BaseStore + + +class BucketStore(BaseStore): + + def __init__(self, dp_parallel_mode): + super().__init__(dp_parallel_mode) + self._grads = dict() + self._params = dict() + self._num_elements_in_bucket = dict() + + self.reset() + + def num_elements_in_bucket(self, reduce_rank: int = None): + return self._num_elements_in_bucket[reduce_rank] + + def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None): + self._num_elements_in_bucket[reduce_rank] += num_elements + + def add_grad(self, tensor, reduce_rank: int = None): + self._grads[reduce_rank].append(tensor) + + def add_param(self, tensor, reduce_rank: int = None): + self._params[reduce_rank].append(tensor) + + def reset(self): + keys = [None] + list(range(self._world_size)) + self._grads = {rank: [] for rank in keys} + self._params = {rank: [] for rank in keys} + self._num_elements_in_bucket = {rank: 0 for rank in keys} + + def reset_by_rank(self, reduce_rank=None): + self._grads[reduce_rank] = [] + self._params[reduce_rank] = [] + self._num_elements_in_bucket[reduce_rank] = 0 + + def get_grad(self, reduce_rank: int = None): + return self._grads[reduce_rank] + + def get_param(self, reduce_rank: int = None): + return self._params[reduce_rank] diff --git a/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py b/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py new file mode 100644 index 0000000000000000000000000000000000000000..8a9128a189642e902c6d0d41a0f50e2505844512 --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py @@ -0,0 +1,66 @@ +from typing import List + +from torch import Tensor + +from .base_store import BaseStore + + +class GradientStore(BaseStore): + + def __init__(self, *args): + super().__init__(*args) + # bookkeeping data structures + self._averaged_gradients = dict() + + # for backward reduction hooks + self._grad_acc_objs = [] + + def add_accumulate_grad_object(self, obj): + """ + Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not + be attached successfully. + + :param obj: An object of :class:`AccumulateGrad` class + :type obj: :class:`AccumulateGrad` + """ + + self._grad_acc_objs.append(obj) + + def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]: + """ + Return average gradients of a parameter group + + :param group_id: The index of parameter group + :type group_id: int + + :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter. + :rtype: List[torch.Tensor] + """ + + return self._averaged_gradients[group_id] + + def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None: + """ + Append an average gradient to the list of averaged gradients of a parameter group + + :param group_id: The index of a parameter group + :param tensor: A :class:`torch.Tensor` object + :type group_id: int + :type tensor: torch.Tensor + + """ + + if group_id in self._averaged_gradients: + self._averaged_gradients[group_id].append(tensor) + else: + self._averaged_gradients[group_id] = [tensor] + + def reset_average_gradients_by_group(self, group_id: int) -> None: + """ + Reset the bookkeeping data structure for averaged gradients to an empty list + + :param group_id: The index of a parameter group + :type group_id: int + """ + + self._averaged_gradients[group_id] = [] diff --git a/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py b/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py new file mode 100644 index 0000000000000000000000000000000000000000..09ebaaf9938cca29afbc38f554d2e0206fa92c91 --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py @@ -0,0 +1,96 @@ +from typing import List + +from torch import Tensor + +from .base_store import BaseStore + + +class ParameterStore(BaseStore): + + def __init__(self, dp_paralle_mode): + super().__init__(dp_paralle_mode) + # param partitioning data structures + self._fp16_param_to_rank = dict() + self._rank_groupid_to_fp16_param_list = dict() + self._rank_group_id_to_flat_fp16_param = dict() + + # param reduction data structures + self._is_param_reduced = dict() + self._reduced_param = [] + + def set_param_to_rank(self, tensor: Tensor, rank: int) -> None: + """ + Set the mapping between parameter to rank, each parameter should be owned by a rank. + + :param tensor: A :class:`torch.Tensor` object + :type tensor: torch.Tensor + :param rank: The rank of which the process is responsible for updating the parameter + :type rank: int + """ + + self._fp16_param_to_rank[tensor] = rank + + def get_param_rank(self, tensor: Tensor) -> int: + """ + Gives the rank which the parameter belongs to + + :param tensor: A :class:`torch.Tensor` object + :type tensor: torch.Tensor + """ + return self._fp16_param_to_rank[tensor] + + def belongs_to_current_rank(self, tensor) -> bool: + """ + Check whether a parameter is supposed to be updated by the process of the current rank + + :param tensor: A :class:`torch.Tensor` object + :type tensor: torch.Tensor + + :return: True if the parameter should be updated by the current rank. Otherwise false. + :rtype: bool + """ + + tensor_rank = self._fp16_param_to_rank[tensor] + return tensor_rank == self._local_rank + + def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None: + if rank not in self._rank_groupid_to_fp16_param_list: + self._rank_groupid_to_fp16_param_list[rank] = dict() + + if group_id not in self._rank_groupid_to_fp16_param_list[rank]: + self._rank_groupid_to_fp16_param_list[rank][group_id] = [] + + self._rank_groupid_to_fp16_param_list[rank][group_id].extend(tensor_list) + + def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]: + return self._rank_groupid_to_fp16_param_list[rank][group_id] + + def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None: + if rank not in self._rank_group_id_to_flat_fp16_param: + self._rank_group_id_to_flat_fp16_param[rank] = dict() + + self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor + + def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor: + return self._rank_group_id_to_flat_fp16_param[rank][group_id] + + def is_param_reduced(self, tensor): + return self._is_param_reduced[tensor] + + def set_param_reduction_state(self, tensor, state): + self._is_param_reduced[tensor] = state + + def get_param_reduction_states(self): + return self._is_param_reduced + + def reset_previous_reduced_params(self): + self._reduced_param = [] + + def add_previous_reduced_param(self, tensor): + self._reduced_param.append(tensor) + + def clear_grads_of_previous_reduced_params(self): + if len(self._reduced_param) > 0: + for param in self._reduced_param: + param.grad = None + self.reset_previous_reduced_params() diff --git a/colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py b/colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py new file mode 100644 index 0000000000000000000000000000000000000000..b32816a046cd6a156196e84e957e070c4401d555 --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py @@ -0,0 +1,53 @@ +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + + +class TensorBucket: + + def __init__(self, size): + self._max_size = size + self._current_size = 0 + self._bucket = [] + + @property + def max_size(self): + return self._max_size + + @property + def current_size(self): + return self._current_size + + def is_full_or_oversized(self): + return self._current_size >= self._max_size + + def is_empty(self): + return len(self._bucket) == 0 + + def add_to_bucket(self, tensor, allow_oversize=False): + tensor_size = tensor.numel() + + if not allow_oversize and self.will_exceed_max_size(tensor_size): + msg = f"The param bucket max size {self._max_size} is exceeded" \ + + f"by tensor (size {tensor_size})" + raise RuntimeError(msg) + + self._bucket.append(tensor) + self._current_size += tensor_size + + def will_exceed_max_size(self, tensor_size): + expected_size = self._current_size + tensor_size + return expected_size > self._max_size + + def get_bucket(self): + return self._bucket + + def empty(self): + self._bucket = [] + self._size = 0 + + def flatten(self): + return _flatten_dense_tensors(self._bucket) + + def unflatten_and_copy(self, flat_tensor): + unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket) + for old, new in zip(self._bucket, unflattened_tensor_list): + old.copy_(new) diff --git a/colossalai/zero/sharded_optim/low_level_optim.py b/colossalai/zero/sharded_optim/low_level_optim.py new file mode 100644 index 0000000000000000000000000000000000000000..d30b69e7ebfdbf7bdfa77a6f0836c1b57eda9ad9 --- /dev/null +++ b/colossalai/zero/sharded_optim/low_level_optim.py @@ -0,0 +1,599 @@ +from functools import partial +from itertools import groupby + +import torch +import torch.distributed as dist +from torch.optim import Optimizer + +from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.utils.cuda import get_current_device + +from ._utils import ( + calculate_global_norm_from_list, + compute_norm, + flatten, + get_grad_accumulate_object, + has_inf_or_nan, + reduce_tensor, + release_param_grad, + split_half_float_double, + sync_param, +) +from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket + + +class LowLevelZeroOptimizer(ColossalaiOptimizer): + """Optimizer used for ZeRO-1 and ZeRO-2. + """ + + def __init__( + self, + optimizer: Optimizer, + + # grad scaler config + initial_scale=2**32, + min_scale=1, + growth_factor=2, + backoff_factor=0.5, + growth_interval=1000, + hysteresis=2, + max_scale: int = 2**32, + + # grad clipping + clip_grad_norm=0.0, + verbose=False, + + # communication + reduce_bucket_size=1024 * 1024, + communication_dtype=None, + overlap_communication=False, + + # stage 2 + partition_grad=False, + dp_parallel_mode=ParallelMode.DATA, + mp_parallel_mode=ParallelMode.MODEL, + + # cpu offload + cpu_offload=False, + + # forced dtype + forced_dtype=None): + + # TODO: add support for + # 1. fp16 master weights + # 2. contiguous gradients + # 3. cpu offload + # 4. support when some parameters requires_grad = False + + self._optimizer = optimizer + self._dtype = self._optimizer.param_groups[0]['params'][0].dtype + self._logger = get_dist_logger() + self._verbose = verbose + + # stage 2 + self._partition_grads = partition_grad + + # cpu_offload + self._cpu_offload = cpu_offload + + # get process groups + self._dp_parallel_mode = dp_parallel_mode + self._mp_parallel_mode = mp_parallel_mode + self._local_rank = gpc.get_local_rank(dp_parallel_mode) + self._world_size = gpc.get_world_size(dp_parallel_mode) + + self._dp_group = gpc.get_group(dp_parallel_mode) + if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1: + self._mp_group = gpc.get_group(mp_parallel_mode) + else: + self._mp_group = None + + # fp16 and fp32 params for mixed precision training + self._fp16_param_groups = dict() + self._fp32_flat_param_groups_of_current_rank = dict() + + # communication params + self._overlap_communication = overlap_communication + self._reduce_bucket_size = reduce_bucket_size + self._communication_dtype = communication_dtype + + # gradient scaler + self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + verbose=verbose) + self._found_overflow = torch.FloatTensor([0]).to(get_current_device()) + + # gradient clipping + self._clip_grad_norm = clip_grad_norm + + if forced_dtype: + for group in self._optimizer.param_groups: + group_params = group['params'] + for param in group_params: + param.data = param.data.to(forced_dtype) + self._dtype = forced_dtype + + # check argument conflict + self._sanity_checks() + + # ParameterStore will manage the tensor buffers used for zero + # it will not manage the tensors used by mixed precision training + self._param_store = ParameterStore(self._dp_parallel_mode) + self._grad_store = GradientStore(self._dp_parallel_mode) + self._bucket_store = BucketStore(self._dp_parallel_mode) + + # iterate over the param group in the optimizer + # partition these param groups for data parallel training + # and add buffers to parameter store for future access + for group_id, param_group in enumerate(self._optimizer.param_groups): + group_params = param_group['params'] + + # add the fp16 params to fp16_param_groups for bookkeeping + self._fp16_param_groups[group_id] = group_params + + # assign parameters to ranks + # the params in the list are sorted + params_per_rank = self._partition_param_list(group_params) + + # store the mapping between param to rank + # each param should belong to only one rank + for rank, params in enumerate(params_per_rank): + self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params) + for param in params: + self._param_store.set_param_to_rank(param, rank) + + # move to cpu to make room to create the flat tensor + # move_tensor(params, device='cpu') + for param in group_params: + param.data = param.data.cpu() + + # flatten the reordered tensors + for rank in range(self._world_size): + tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) + with torch.no_grad(): + flat_tensor = flatten(tensor_list) + flat_tensor = flat_tensor.data.cuda() + self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor) + + # sync parameters + for rank in range(self._world_size): + flat_tensor = self._param_store.get_flat_fp16_param_by_rank_group(rank, group_id) + tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) + sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list) + + # create a copy of fp32 weights of the parameters for which this rank is responsible + fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id) + fp32_flat_current_rank = fp16_flat_current_rank.float() + device = 'cpu' if self._cpu_offload else get_current_device() + fp32_flat_current_rank = fp32_flat_current_rank.to(device) + fp32_flat_current_rank.requires_grad = True + self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank + + # need to replace the params in the `params` field in the optimizer + # so that when the optimizer calls step(), it only updates the tensors + # managed by this data parallel rank + param_group['params'] = [fp32_flat_current_rank] + + # set reduction state + for param in self._fp16_param_groups[group_id]: + self._param_store.set_param_reduction_state(param, False) + + # intialize communication stream for + # communication-compuation overlapping + if self._overlap_communication: + self._comm_stream = torch.cuda.Stream() + + # reduction hook is only used if overlapping communication + # or stage 2 is used + # if it is stage 1 without overlapping, no hook will be attached + if self._overlap_communication or self._partition_grads: + self._attach_reduction_hook() + + self._initialize_optimizer_states() + + @property + def loss_scale(self): + return self.grad_scaler.scale + + @property + def num_param_groups(self): + return len(self._fp16_param_groups) + + def _partition_param_list(self, param_list): + params_per_rank = [[] for _ in range(self._world_size)] + numel_per_rank = [0 for _ in range(self._world_size)] + + # partititon the parameters in a greedy fashion + sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True) + for param in sorted_params: + # allocate this parameter to the rank with + # the smallest numel for load balancing purpose + rank_to_go = numel_per_rank.index(min(numel_per_rank)) + params_per_rank[rank_to_go].append(param) + numel_per_rank[rank_to_go] += param.numel() + + if self._verbose: + self._logger.info(f'Number of elements on ranks: {numel_per_rank}', + ranks=[0], + parallel_mode=self._dp_parallel_mode) + return params_per_rank + + def _initialize_optimizer_states(self): + # create a dummy zero tensor which has the same shape as that of the param + # set this dummpy zero tensor as grad + for group_id in range(len(self._fp32_flat_param_groups_of_current_rank)): + fp32_partition_param = self._fp32_flat_param_groups_of_current_rank[group_id] + fp32_partition_grad = torch.zeros_like(fp32_partition_param) + fp32_partition_param.grad = fp32_partition_grad + + # we do not need log information for optimizer, so comment them + # update the parameter with zero gradients for initialization of optimizer states + # self._optimizer.step() + + # remove the grad of the paramter to save memory + # for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items(): + # fp32_flat_tensor.grad = None + + def _sanity_checks(self): + assert torch.cuda.is_available(), 'CUDA is required' + for param_group in self._optimizer.param_groups: + group_params = param_group['params'] + for param in group_params: + assert param.dtype == self._dtype, \ + f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" + + ########################################################### + # Backward Reduction Hook + ########################################################### + + def _attach_reduction_hook(self): + # we iterate over the fp16 params + # on each param, we register a hook to its AccumulateGrad object + for group_id in range(self.num_param_groups): + param_group = self._fp16_param_groups[group_id] + for param in param_group: + if param.requires_grad: + # determines the reduction destionation rank + # this is only valid for stage 2 + # dst_rank = None means using all-reduce + # else using reduce + if self._partition_grads: + reduce_rank = self._param_store.get_param_rank(param) + else: + reduce_rank = None + + def _define_and_attach(param, reduce_rank): + # get the AccumulateGrad object of the param itself + accum_grad_obj = get_grad_accumulate_object(param) + self._grad_store.add_accumulate_grad_object(accum_grad_obj) + + reduction_func = partial(self._reduce_and_remove_grads_by_bucket, + param=param, + reduce_rank=reduce_rank) + + # define hook + # NOT IMPORTANT BUT GOOD TO KNOW: + # args here is not grad, but allow_unreacable and accumulate_grad + def reduce_grad_hook(*args): + reduction_func() + + accum_grad_obj.register_hook(reduce_grad_hook) + + _define_and_attach(param, reduce_rank) + + def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None): + param_size = param.numel() + + # check if the bucket is full + # if full, will reduce the grads already in the bucket + # after reduction, the bucket will be empty + if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size: + self._reduce_grads_in_bucket(reduce_rank) + + # the param must not be reduced to ensure correctness + is_param_reduced = self._param_store.is_param_reduced(param) + if is_param_reduced: + msg = f'Parameter of size ({param.size()}) has already been reduced, ' \ + + 'duplicate reduction will lead to arithmetic incorrectness' + raise RuntimeError(msg) + + # the param must have grad for reduction + assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced' + + self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank) + self._bucket_store.add_grad(param.grad, reduce_rank) + self._bucket_store.add_param(param, reduce_rank) + + def _reduce_grads_in_bucket(self, reduce_rank=None): + # reduce grads + self._reduce_grads_by_rank(reduce_rank=reduce_rank, + grads=self._bucket_store.get_grad(reduce_rank=reduce_rank), + bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank)) + + # use communication stream if overlapping + # communication with computation + if self._overlap_communication: + stream = self._comm_stream + else: + stream = torch.cuda.current_stream() + + with torch.cuda.stream(stream): + params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank) + + for param in params_in_bucket: + # the is_param_reduced flag should be False showing that + # this param is not reduced before calling self._reduce_grads_by_rank + is_param_reduced = self._param_store.is_param_reduced(param) + + if is_param_reduced: + msg = f'Parameter of size ({param.size()}) has been reduced, ' + \ + 'duplicate reduction will lead to arithmetic incorrectness' + raise RuntimeError(msg) + + # update the flag + self._param_store.set_param_reduction_state(param, True) + + # if partition grads = True + # we do not keep the gradient after reduction + if self._partition_grads and not self._param_store.belongs_to_current_rank(param): + if self._overlap_communication: + # we need to keep this gradient for now as reduction may + # be completed yet since it is using a different cuda stream + self._param_store.add_previous_reduced_param(param) + else: + param.grad = None + + self._bucket_store.reset_by_rank(reduce_rank) + + def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size): + grad_buckets_by_dtype = split_half_float_double(grads) + + for tensor_list in grad_buckets_by_dtype: + self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank) + + ############################## + # Reduction Utility Function # + ############################## + def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank): + param_bucket = TensorBucket(size=bucket_size) + + for tensor in tensor_list: + param_bucket.add_to_bucket(tensor, allow_oversize=True) + + if param_bucket.is_full_or_oversized(): + self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank) + param_bucket.empty() + + if not param_bucket.is_empty(): + self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank) + + def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank): + if self._overlap_communication: + torch.cuda.synchronize() + self._param_store.clear_grads_of_previous_reduced_params() + stream = self._comm_stream + else: + stream = torch.cuda.current_stream() + + with torch.cuda.stream(stream): + flat = bucket.flatten() + reduced_flat = reduce_tensor(tensor=flat, + dtype=self._communication_dtype, + dst_rank=reduce_rank, + parallel_mode=self._dp_parallel_mode) + + # update the reduced tensor + if reduce_rank is None or reduce_rank == self._local_rank: + bucket.unflatten_and_copy(reduced_flat) + + ################################ + # torch.optim.Optimizer methods + ################################ + + def backward(self, loss, retain_graph=False): + loss = self.loss_scale * loss + loss.backward(retain_graph=retain_graph) + + # finish gradient reduction + if not self._partition_grads: + self._reduce_grad_stage1() + else: + # TODO: support async comm in reduce + self._reduce_grad_stage2() + + # clear reduced grads + if self._overlap_communication: + torch.cuda.synchronize() + self._param_store.clear_grads_of_previous_reduced_params() + + def zero_grad(self, set_to_none=True): + """ + Set parameter gradients to zero. If set_to_none = True, gradient + will be set to None to save memory. + + :param set_to_none: Whether set the gradient to None. Default value is True. + :type set_to_none: bool + """ + for group_id, param_group in self._fp16_param_groups.items(): + for param in param_group: + if set_to_none: + param.grad = None + else: + if param.grad is not None: + param.grad.detach() + param.grad.zero_() + + #################### + # Update Parameter # + #################### + + def step(self, closure=None): + assert closure is None, 'closure is not supported by step()' + + # check for overflow + found_inf = self._check_overflow() + self.grad_scaler.update(found_inf) + + # update loss scale if overflow occurs + if found_inf: + self._grad_store._averaged_gradients = dict() + self.zero_grad() + return + + # copy the grad of fp16 param to fp32 param + single_grad_partition_groups = [] + norm_groups = [] + + for group_id in range(self.num_param_groups): + # compute norm + norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id], + params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id, + rank=self._local_rank), + dp_group=self._dp_group, + mp_group=self._mp_group) + norm_groups.append(norm_group) + + # create flat gradient for the flat fp32 params + fp16_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id) + flat_fp16_avg_grads = flatten(fp16_avg_grads) + + dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype + flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype) + + param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape + assert param_shape == flat_fp32_avg_grads.shape, \ + f'fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}' + + single_grad_partition_groups.append(flat_fp32_avg_grads) + device = self._fp32_flat_param_groups_of_current_rank[group_id].device + self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device) + self._grad_store._averaged_gradients[group_id] = [] + self._grad_store._averaged_gradients[group_id] = [] + + # unscale and clip grads + global_norm = calculate_global_norm_from_list(norm_list=norm_groups) + self._unscale_and_clip_grads(single_grad_partition_groups, global_norm) + + # update the parameters + self._optimizer.step() + # release the fp32 grad + release_param_grad(self._fp32_flat_param_groups_of_current_rank.values()) + + # update fp16 partition updated by the current rank + for group_id in range(len(self._fp16_param_groups)): + fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=self._local_rank, group_id=group_id) + fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id] + fp16_param.data.copy_(fp32_param) + + # broadcast the updated model weights + handles = [] + for group_id in range(self.num_param_groups): + for rank in range(self._world_size): + fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id) + handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True) + handles.append(handle) + + for handle in handles: + handle.wait() + + ################## + # FP16 Utilities # + ################## + + def _check_overflow(self): + # clear previous overflow record + self._found_overflow.fill_(0.0) + + # check for overflow + for group_id in range(len(self._fp16_param_groups)): + for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id): + if avg_grad is not None and has_inf_or_nan(avg_grad): + self._found_overflow.fill_(1.0) + break + + # all-reduce across dp group + dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_group) + + # all-reduce over model parallel group + if self._mp_group: + dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group) + + if self._found_overflow.item() > 0: + return True + else: + return False + + def _unscale_and_clip_grads(self, grad_groups_flat, total_norm): + # compute combined scale factor for this group + combined_scale = self.loss_scale + + if self._clip_grad_norm > 0.: + # norm is in fact norm*scale + clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm + if clip > 1: + combined_scale = clip * self.loss_scale + + for grad in grad_groups_flat: + grad.data.mul_(1. / combined_scale) + + ############################ + # Gradient Synchronization # + ############################ + + def sync_grad(self): + # update param already reduced flag + reduction_states = self._param_store.get_param_reduction_states() + for tensor, state in reduction_states.items(): + reduction_states[tensor] = False + + # accumulate gradient + avg_gradients = self._grad_store._averaged_gradients + for group_id in range(self.num_param_groups): + param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id) + + if group_id not in avg_gradients: + avg_gradients[group_id] = [] + + param_idx = 0 + for param in param_group: + if param.grad is not None: + if len(avg_gradients[group_id]) == param_idx: + avg_gradients[group_id].append(param.grad) + else: + avg_gradients[group_id][param_idx].add_(param.grad) + param_idx += 1 + + # the gradients needed are stored in the avg_gradients buffer + # thus, can clear this + self.zero_grad() + + def _reduce_grad_stage1(self): + # if not overlapping communication (no reduction hook is attached) + # we need to manually reduce these gradients + if not self._overlap_communication: + for group_id in range(len(self._fp16_param_groups)): + param_group = self._fp16_param_groups[group_id] + for param in param_group: + if param.grad is not None: + self._reduce_and_remove_grads_by_bucket(param) + + # we need to reduce the gradients + # left in the communication bucket + self._reduce_grads_in_bucket() + + def _reduce_grad_stage2(self): + # when partition_grads is True, reduction hooks + # are attached in the __init__ function, so we + # only need to reduce the gradients + # left in the communication bucket + for reduce_rank in range(self._world_size): + self._reduce_grads_in_bucket(reduce_rank) diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..401ff988df4acd6b97d8a2c585a1ab2d1f18b45e --- /dev/null +++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py @@ -0,0 +1,386 @@ +from enum import Enum +from os import stat +from typing import Dict, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.gemini.tensor_utils import (colo_model_data_tensor_move_inline, colo_tensor_mem_usage) +from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32 +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter +from torch.optim import Optimizer +from colossalai.gemini.stateful_tensor import (StatefulTensor, TensorState) +from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy + + +class OptimState(Enum): + SCALED = 1 + UNSCALED = 2 + + +class ShardedOptimizerV2(ColossalaiOptimizer): + """A wrapper for optimizer. ``ShardedOptimizerV2`` and ``ShardedModelV2`` implement Zero Redundancy Optimizer (ZeRO). + + By default the ZeRO optimizer stage 3 offload Optimizer States on CPU. + + We apply the Device-aware Operator Placement technique for OS placement from the following paper. + + `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_ + + GPU margin space is the remaining space after removing peak non-model data from the overall GPU memory, + which is detected by a runtime memory tracer. + + We place as many OS chunks in the margin space as possible. + + The size of margin space can be controlled by ``gpu_margin_mem_ratio``. + If it is set as ``0.0``, it is the same as classical ZeRO optimizer. + + Note: + You must use ``ShardedOptimizerV2`` with ``ShardedModelV2``. + + Note: + Make sure you set ``tensor_placement_policy`` in ``ShardedModelV2`` to `"auto"`, + if you set ``gpu_margin_mem_ratio > 0``. + + Args: + sharded_model (ShardedModelV2): A sharded model initialized by class ShardedModelV2. The optimizer will use the + shard strategy provided by sharded model to shard param fp32 tensors. + optimizer (Optimizer): An Optimizer instance. + gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) + which will be used when using hybrid CPU optimizer. + This argument is meaningless when `tensor_placement_policy` of `ShardedModelV2` is not "auto". + Defaults to 0.0. + initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. + min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. + growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. + backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5. + growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000. + hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2. + max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32. + dp_process_group (Optional[ProcessGroup], optional): data paralle process group. Defaults to None. + mp_process_group (Optional[ProcessGroup], optional): model paralle process group. Defaults to None. + + .. _PatrickStar\: Parallel Training of Pre-trained Models via Chunk-based Memory Management: + https://arxiv.org/abs/2108.05818 + """ + + def __init__(self, + sharded_model: ShardedModelV2, + optimizer: Optimizer, + gpu_margin_mem_ratio: float = 0.0, + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + dp_process_group: Optional[ProcessGroup] = None, + mp_process_group: Optional[ProcessGroup] = None, + verbose: bool = False) -> None: + assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel' + assert not isinstance(optimizer, ShardedOptimizerV2), 'Nested ShardedOptimizerV2 is not supported.' + + super().__init__(optimizer) + self.shard_strategy = sharded_model.shard_strategy + self.model: ShardedModelV2 = sharded_model + + self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) + assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0' + # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid + # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors, + # and it must set `num_fp32_shards_per_param` correctly + self._should_move_fp32_shards_h2d: bool = sharded_model.cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr( + optimizer, 'num_fp32_shards_per_param', 0) >= 2 + self.device = sharded_model._tensor_placement_policy.device or torch.device('cpu') + self.optim_state: OptimState = OptimState.UNSCALED + self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA) + self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL) + # Grad scaler + self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale) + self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device()) + self._logger = get_dist_logger("ShardedOptimizerV2") + self._verbose = verbose + + # Store fp32 param shards + self._register_master_weight() + if self.gpu_margin_mem_ratio != 0.0 and not isinstance(sharded_model._tensor_placement_policy, + AutoTensorPlacementPolicy): + self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"', + ranks=[0]) + + if self._verbose: + self._logger.debug( + f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0]) + + self._use_memory_tracer = self.model.use_memory_tracer + + @property + def loss_scale(self): + return self.grad_scaler.scale.item() + + def get_memory_usage(self) -> Tuple[int, int]: + """ Get the memory usage of the optimizer. Including master_params (param fp32), + momentum (``self.state[p]['exp_avg']``) variance (``self.state[p]['exp_avg_sq']``) + + Returns: + Tuple[int, int]: cuda/cpu memory usage in Byte. + """ + cuda_use = 0 + cpu_use = 0 + + def update_mem_use(t): + nonlocal cuda_use + nonlocal cpu_use + t_cuda_use, t_cpu_use = colo_tensor_mem_usage(t) + cuda_use += t_cuda_use + cpu_use += t_cpu_use + + for _, p_fp32 in self.master_params.items(): + update_mem_use(p_fp32) + for group in self.optim.param_groups: + for p in group['params']: + state = self.optim.state[p] + for k, v in state.items(): + update_mem_use(v) + + return cuda_use, cpu_use + + def zero_grad(self, *args, **kwargs): + self._zero_grad() + + def backward(self, loss: Tensor) -> None: + loss = self.loss_scale * loss + self.optim_state = OptimState.SCALED + self.model.backward(loss) + + def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None: + # This function is called except the last stage of pipeline parallel + # It receives the scaled grad from the previous rank + # No need to scale the grad again + # Need to unscale when optimizing + self.optim_state = OptimState.SCALED + self.model.backward_by_grad(tensor, grad) + + def clip_grad_norm(self, model: nn.Module, max_norm: float): + if self.optim_state == OptimState.SCALED: + self._prepare_grads() + self._unscale_grads() + return super().clip_grad_norm(model, max_norm) + + def step(self, *args, **kwargs): + + # unscale grads if scaled + if self.optim_state == OptimState.SCALED: + self._prepare_grads() + self._unscale_grads() + + self._maybe_move_fp32_shards() + found_inf = self._check_overflow() + self.grad_scaler.update(found_inf) + + if found_inf: + self._logger.warning('found inf during ShardedOptimV2 step') + self._zero_grad(recover_data=True) + return + + self._point_param_fp16_to_master_param() + + if self._verbose: + gpu_mem, cpu_mem = self.get_memory_usage() + self._logger.debug( + f"Before step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!", + ranks=[0]) + ret = self.optim.step(*args, **kwargs) + + if self._verbose: + gpu_mem, cpu_mem = self.get_memory_usage() + self._logger.debug( + f"After step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!", + ranks=[0]) + + self._copy_master_model_to_model_fp16() + return ret + + def _check_overflow(self): + # clear previous overflow record + self._found_overflow.fill_(self.model.overflow_counter) + + # all-reduce across dp group + dist.all_reduce(self._found_overflow, group=self.dp_process_group) + + # all-reduce over model parallel group + dist.all_reduce(self._found_overflow, group=self.mp_process_group) + + return self._found_overflow.item() > 0 + + def _unscale_grads(self): + assert self.optim_state == OptimState.SCALED + for group in self.optim.param_groups: + for p in group['params']: + if p.grad is not None: + p.grad.data.div_(self.loss_scale) + self.optim_state = OptimState.UNSCALED + + def _zero_grad(self, recover_data: bool = False): + """zero grad and maybe recover fp16 params + When `reuse_fp16_shard` is enabled, + p.colo_attr.sharded_data_tensor stores grad here. + We have to recover them from fp32 params. + + Args: + recover_data (bool, optional): Whether to recover fp16 param from fp32 param. Defaults to False. + """ + # We must set grad to None + # Because grad here is sharded + # But next backward pass will create a full grad first + # Which leads to wrong accumulation + self.optim.zero_grad(set_to_none=True) + for group in self.optim.param_groups: + for p in group['params']: + # p.colo_attr.sharded_data_tensor stores grad now + # we have to recover fp16 param + reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0) + if recover_data and reuse_fp16_shard: + self._copy_master_param_to_param_fp16(p) + else: + # release saved gradient + p.colo_attr.saved_grad.set_null() + self.model.overflow_counter = 0 # set overflow counter to zero + + def sync_grad(self): + pass + + def _register_master_weight(self): + self.master_params: Dict[Parameter, StatefulTensor] = {} + for group in self.optim.param_groups: + for p in group['params']: + assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam' + shard_flag = not p.colo_attr.sharded_data_tensor.is_sharded and p.colo_attr.is_replicated + if shard_flag: + # we always shard replicated paramters + self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group) + self.master_params[p] = StatefulTensor(cast_tensor_to_fp32(p.colo_attr.data_payload.to(self.device))) + if shard_flag: + # In this branch, there's no need to shard param + # So we gather here + self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group) + + def _maybe_move_fp32_shards(self): + if self._should_move_fp32_shards_h2d: + self._should_move_fp32_shards_h2d = False + available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio + fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param + fp32_shards_used_cuda_margin_mem = 0 + for group in self.optim.param_groups: + for p in group['params']: + if p.colo_attr.saved_grad.is_null(): + continue + shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size() + if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem: + colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device()) + colo_model_data_tensor_move_inline(p.colo_attr.saved_grad, torch.cuda.current_device()) + p.colo_attr.offload_grad = False + fp32_shards_used_cuda_margin_mem += shard_mem + state = self.optim.state[p] + for k, v in state.items(): + if isinstance(v, Tensor): + state[k] = v.cuda() + + def _prepare_grads(self): + for group in self.optim.param_groups: + for p in group['params']: + if p.colo_attr.saved_grad.is_null(): + continue + p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE) + # If reuse_fp16_shard, grad fp16 which wasn't be offloaded may be evicted to CPU + if not p.colo_attr.offload_grad: + colo_model_data_tensor_move_inline(p.colo_attr.saved_grad, torch.cuda.current_device()) + # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation + # If we change p.grad directly + # it may raise error because of different shape/dtype/device of p.data and p.grad + # We just set p.data = p.colo_attr.saved_grad.payload here + p.data = p.colo_attr.grad_payload + p.grad = p.colo_attr.grad_payload + # Set p.data to empty tensor, in case of memory leaking + p.colo_attr.set_data_none() + + def _point_param_fp16_to_master_param(self): + # assign master param pointers to p.data. + # We will not trigger data copy here. + for group in self.optim.param_groups: + for p in group['params']: + self.master_params[p].trans_state(TensorState.COMPUTE) + p.data = self.master_params[p].payload + # Now p.data is sharded + # So optimizer states are sharded naturally + + def _copy_master_model_to_model_fp16(self): + # Copy master param data (fp32) to payload of colo_attr (fp16) + # TODO() improve efficiency by gathering tensors into a chunk and transfering + # a chunk. + for group in self.optim.param_groups: + for p in group['params']: + self._copy_master_param_to_param_fp16(p) + + def _copy_master_param_to_param_fp16(self, p): + # flush gradient + if p.colo_attr.sharded_data_tensor.payload_size == 0: + # here reuse_fp16_shard is True + # in order to use copy below, we should give sharded data tensor a payload + p.colo_attr.sharded_data_tensor.payload_relay(p.colo_attr.saved_grad) + else: + p.colo_attr.saved_grad.set_null() + + p.data = self.master_params[p].payload + + # we need to allocate new memory for keep_not_shard paramters + # in order to use copy, otherwise, the sizes of tensor is not compatible + if p.colo_attr.data_payload.numel() != p.data.numel(): + p.colo_attr.data_payload_reset( + torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device)) + + # TODO() optimize this line CPU (fp32) -> GPU (fp16) + p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach()) + p.colo_attr.set_data_none() + + if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated: + # We gather full fp16 param here + p.colo_attr.sharded_data_tensor.is_sharded = True # since only gradient is sharded, we should set to True + self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group) + + self.master_params[p].trans_state(TensorState.HOLD) + + def state_dict(self): + optim_state_dict = super().state_dict() + scaler_state_dict = self.grad_scaler.state_dict() + optim_state_dict['scaler'] = scaler_state_dict + return optim_state_dict + + def load_state_dict(self, *args, **kwargs): + if 'scaler' not in args[0]: + self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0]) + else: + scaler_state_dict = args[0].pop('scaler') + self.grad_scaler.load_state_dict(scaler_state_dict) + super().load_state_dict(*args, **kwargs) + for group in self.optim.param_groups: + for p in group['params']: + state = self.optim.state[p] + for k, v in state.items(): + if isinstance(v, Tensor): + state[k] = v.to(dtype=self.master_params[p].dtype, device=self.master_params[p].device) diff --git a/colossalai/zero/sharded_param/__init__.py b/colossalai/zero/sharded_param/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5642a504acf7da85db12c8464c64cf793b8bd134 --- /dev/null +++ b/colossalai/zero/sharded_param/__init__.py @@ -0,0 +1,4 @@ +from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor +from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 + +__all__ = ['ShardedTensor', 'ShardedParamV2'] diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py new file mode 100644 index 0000000000000000000000000000000000000000..db0f2d1494313ee8aab415384f00e8fd748a7e4c --- /dev/null +++ b/colossalai/zero/sharded_param/sharded_param.py @@ -0,0 +1,108 @@ +import torch +from typing import Optional, Tuple +from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor +from colossalai.gemini.tensor_utils import colo_tensor_mem_usage +from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState +from typing import List + +EMPTY_TENSOR_DICT = {} + + +def get_empty_tensor(device: torch.device, dtype: torch.dtype): + key = (device, dtype) + if key not in EMPTY_TENSOR_DICT: + EMPTY_TENSOR_DICT[key] = torch.empty(0, dtype=dtype, device=device) + + return EMPTY_TENSOR_DICT[key] + + +class ShardedParamV2(object): + + def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None: + self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data) + self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE) + # This attribute must be initialized in ShardedModel + self.offload_grad: bool = False + + # make sure the shared param is the only owner of payload + # The param.data maybe used to init the other part of the model. + # For example: File "resnet.py", line 190, in __init__ + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + # So we can not empty the .data at this time + self.param = param + if set_data_none: + self.set_data_none() + + def get_payload_tensors(self) -> List[StatefulTensor]: + """returns stateful tensors kept by this class. + """ + return [self._sharded_data_tensor] + + def set_data_none(self): + self.param.data = get_empty_tensor(self.sharded_data_tensor.device, self.sharded_data_tensor.dtype) + + def set_grad_none(self): + self.saved_grad.set_null() + + @property + def sharded_data_tensor(self): + return self._sharded_data_tensor + + @property + def data_payload(self): + assert not self.sharded_data_tensor.is_null() + return self.sharded_data_tensor.payload + + @property + def grad_payload(self): + assert not self.saved_grad.is_null() + return self.saved_grad.payload + + @property + def param_is_sharded(self): + return self.sharded_data_tensor.is_sharded + + def data_payload_reset(self, tensor: torch.Tensor): + assert type(tensor) is torch.Tensor + assert tensor.requires_grad is False + self.sharded_data_tensor.payload_reset(tensor) + + def grad_payload_reset(self, tensor: torch.Tensor): + assert type(tensor) is torch.Tensor + assert tensor.requires_grad is False + self.saved_grad.payload_reset(tensor) + + def get_memory_usage(self) -> Tuple[int, int]: + """ + get the memory usage of the param, including data and grad + Returns: + Tuple[int, int]: cuda mem usage in Byte, cpu memory usage in Byte + """ + cuda_mem_use, cpu_mem_use = 0, 0 + + def _update_mem_use(t: Optional[torch.Tensor]): + if t is None: + return + assert isinstance(t, torch.Tensor) + nonlocal cuda_mem_use + nonlocal cpu_mem_use + t_cuda, t_cpu = colo_tensor_mem_usage(t) + cuda_mem_use += t_cuda + cpu_mem_use += t_cpu + + address_set = set() + _update_mem_use(self.data_payload) + address_set.add(self.data_payload.data_ptr()) + + if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set: + _update_mem_use(self.grad_payload) + address_set.add(self.saved_grad.data_ptr()) + + if self.param.data is not None and self.param.data.data_ptr() not in address_set: + _update_mem_use(self.param.data) + address_set.add(self.param.data.data_ptr()) + + if self.param.grad is not None and self.param.grad.data_ptr() not in address_set: + _update_mem_use(self.param.grad) + + return cuda_mem_use, cpu_mem_use diff --git a/colossalai/zero/sharded_param/sharded_tensor.py b/colossalai/zero/sharded_param/sharded_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..77f4aec30f3251045b7ee9a61bb7a17b7e1fb112 --- /dev/null +++ b/colossalai/zero/sharded_param/sharded_tensor.py @@ -0,0 +1,39 @@ +import torch +from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState + + +class ShardedTensor(StatefulTensor): + + def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None: + r""" + A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance. + """ + assert tensor.requires_grad is False + super().__init__(tensor, state) + + # kept the shape, numel and dtype of the init tensor. + self._origin_shape = tensor.shape + self._origin_numel = tensor.numel() + self._origin_dtype = tensor.dtype + self._is_sharded = False + + @property + def dtype(self) -> torch.dtype: + assert self._payload.dtype == self._origin_dtype + return self._payload.dtype + + @property + def origin_numel(self) -> int: + return self._origin_numel + + @property + def origin_shape(self) -> int: + return self._origin_shape + + @property + def is_sharded(self): + return self._is_sharded + + @is_sharded.setter + def is_sharded(self, flag: bool): + self._is_sharded = flag diff --git a/colossalai/zero/utils/__init__.py b/colossalai/zero/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4e68722895741f813a4e8950b64dde22a130218 --- /dev/null +++ b/colossalai/zero/utils/__init__.py @@ -0,0 +1,3 @@ +from .zero_hook import ZeroHook + +__all__ = ['ZeroHook'] \ No newline at end of file diff --git a/colossalai/zero/utils/gemini_hook.py b/colossalai/zero/utils/gemini_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..35569c7172b389133364a0cd7d0108b26c26148f --- /dev/null +++ b/colossalai/zero/utils/gemini_hook.py @@ -0,0 +1,67 @@ +from contextlib import contextmanager +from enum import Enum +from functools import partial +from typing import List + +import torch + +from colossalai.gemini import TensorState +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.tensor.param_op_hook import ColoParamOpHook + + +class TrainingPhase(Enum): + FORWARD = 0 + BACKWARD = 1 + + +class GeminiZeROHook(ColoParamOpHook): + + def __init__(self, gemini_manager: GeminiManager) -> None: + super().__init__() + self._gemini_manager = gemini_manager + self._chunk_manager = gemini_manager.chunk_manager + self._training_phase = TrainingPhase.FORWARD + + def pre_op(self, params): + params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)] + chunks = self._chunk_manager.get_chunks(params) + for p in params: + self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) + self._gemini_manager.sample_overall_data() + self._gemini_manager.adjust_layout(chunks) + for chunk in chunks: + self._chunk_manager.access_chunk(chunk) + + # record cuda model data of the current OP + self._gemini_manager.record_model_data_volume() + + def post_op(self, params): + params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)] + for p in params: + tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD + self._chunk_manager.trans_tensor_state(p, tensor_state) + + def pre_forward(self, params: List[torch.Tensor]) -> None: + self.pre_op(params) + + def post_forward(self, params: List[torch.Tensor]) -> None: + self.post_op(params) + + def pre_backward(self, params: List[torch.Tensor]) -> None: + self.pre_op(params) + + def post_backward(self, params: List[torch.Tensor]) -> None: + self.post_op(params) + + @contextmanager + def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD): + old_training_phase = self._training_phase + try: + self._training_phase = training_phase + yield + finally: + self._training_phase = old_training_phase + + switch_to_backward = switch_training_phase + switch_to_forward = partial(switch_to_backward, training_phase=TrainingPhase.FORWARD) diff --git a/colossalai/zero/utils/zero_hook.py b/colossalai/zero/utils/zero_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..87bf2c0f5086e7797c947f523d5a81f98da9dd0d --- /dev/null +++ b/colossalai/zero/utils/zero_hook.py @@ -0,0 +1,118 @@ +from typing import Optional + +import torch +import torch.distributed as dist + +from colossalai.gemini.memory_tracer import MemStatsCollector +from colossalai.gemini.ophooks import BaseOpHook +from colossalai.gemini.stateful_tensor import TensorState +from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.logging import get_dist_logger +from colossalai.registry import OPHOOKS +from colossalai.utils import get_current_device +from colossalai.zero.shard_utils import BaseShardStrategy + + +@OPHOOKS.register_module +class ZeroHook(BaseOpHook): + """ + A hook to process sharded param for ZeRO method. + Warning: this class has been deprecated after version 0.1.12 + """ + + def __init__(self, + shard_strategy: BaseShardStrategy, + memstarts_collector: Optional[MemStatsCollector] = None, + stateful_tensor_mgr: Optional[StatefulTensorMgr] = None, + process_group: Optional[dist.ProcessGroup] = None): + super().__init__() + self.logger = get_dist_logger("ZeROHook") + self.shard_strategy = shard_strategy + self.process_group = process_group + + # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU + self.computing_device = get_current_device() + + self._memstarts_collector = memstarts_collector + self._stateful_tensor_mgr = stateful_tensor_mgr + + def gather_parameters(self, module: torch.nn.Module): + # gather sharded parameters + if module.param_is_sharded: + tensor_list = [] + for param in module.parameters(recurse=False): + assert hasattr(param, 'colo_attr') + tensor_list.append(param.colo_attr.sharded_data_tensor) + self.shard_strategy.gather(tensor_list, self.process_group) + + def shard_parameters(self, module: torch.nn.Module): + # shard gathered parameters + if module.param_is_sharded: + tensor_list = [] + for param in module.parameters(recurse=False): + assert hasattr(param, 'colo_attr') + tensor_list.append(param.colo_attr.sharded_data_tensor) + self.shard_strategy.shard(tensor_list, self.process_group) + + def adjust_module_data(self, module: torch.nn.Module): + # record overall data statistics + if self._memstarts_collector: + self._memstarts_collector.sample_overall_data() + + for param in module.parameters(recurse=False): + param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE) + + # adjust stateful tensor to get enough CUDA memory + self._stateful_tensor_mgr.adjust_layout() + + # record model data statistics + if self._memstarts_collector: + self._memstarts_collector.record_model_data_volume() + + def pre_fwd_exec(self, module: torch.nn.Module, *args): + self.adjust_module_data(module) + self.gather_parameters(module) + for param in module.parameters(recurse=False): + param.data = param.colo_attr.data_payload + assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA" + + def post_fwd_exec(self, module: torch.nn.Module, *args): + + # change tensor state to HOLD_AFTER_FWD + for param in module.parameters(recurse=False): + param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD) + + self.shard_parameters(module) + + # remove torch payload + for param in module.parameters(recurse=False): + param.colo_attr.set_data_none() + + def pre_bwd_exec(self, module: torch.nn.Module, input, output): + self.adjust_module_data(module) + self.gather_parameters(module) + for param in module.parameters(recurse=False): + param.data = param.colo_attr.data_payload + assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA" + + def post_bwd_exec(self, module: torch.nn.Module, input): + + # change tensor state to HOLD_AFTER_BWD + for param in module.parameters(recurse=False): + param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD) + + self.shard_parameters(module) + + # remove torch payload + for param in module.parameters(recurse=False): + param.colo_attr.set_data_none() + + def pre_iter(self): + pass + + def post_iter(self): + if self._stateful_tensor_mgr: + self.logger.debug( + f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB, get layout info time: {self._stateful_tensor_mgr._layout_time}, evict cpu time: {self._stateful_tensor_mgr._evict_time}", + ranks=[0]) + self._stateful_tensor_mgr.finish_iter() diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..bcb7c0fffbb3e306fdb230d1b921ed17ba3834e5 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,24 @@ +FROM hpcaitech/cuda-conda:11.3 + +# install torch +RUN conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch + +# install apex +RUN git clone https://github.com/NVIDIA/apex && \ + cd apex && \ + pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./ + +# install colossalai +RUN git clone https://github.com/hpcaitech/ColossalAI.git \ + && cd ./ColossalAI \ + && pip install -v --no-cache-dir . + +# install titans +RUN pip install --no-cache-dir titans + +# install tensornvme +RUN conda install cmake && \ + git clone https://github.com/hpcaitech/TensorNVMe.git && \ + cd TensorNVMe && \ + pip install -r requirements.txt && \ + pip install -v --no-cache-dir . \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..9f43a48d64206ee9af88cb4fef87363921998174 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,26 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = .build +SPHINXAPIDOC ?= sphinx-apidoc +SPHINX_APIDOC_OPTIONS = members +SPHINX_APIDOC_TEMPLATEDIR = _templates/apidoc + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile apidoc + +apidoc: + @SPHINX_APIDOC_OPTIONS=$(SPHINX_APIDOC_OPTIONS) $(SPHINXAPIDOC) -f -T -e -M -d 2 -t $(SPHINX_APIDOC_TEMPLATEDIR) -o ./colossalai ../colossalai +# @$(SPHINXAPIDOC) -f -o ./model_zoo ../model_zoo +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/css/rtd_theme.css b/docs/_static/css/rtd_theme.css new file mode 100644 index 0000000000000000000000000000000000000000..caf42dc5aaab93f26d417dd132001b5e03e849e0 --- /dev/null +++ b/docs/_static/css/rtd_theme.css @@ -0,0 +1,3 @@ +.wy-nav-content { + max-width: 80%; +} \ No newline at end of file diff --git a/docs/_templates/apidoc/module.rst_t b/docs/_templates/apidoc/module.rst_t new file mode 100644 index 0000000000000000000000000000000000000000..d9a50e6b9752a1b04ef1317c33075e8c19fc97cd --- /dev/null +++ b/docs/_templates/apidoc/module.rst_t @@ -0,0 +1,9 @@ +{%- if show_headings %} +{{- basename | e | heading }} + +{% endif -%} +.. automodule:: {{ qualname }} +{%- for option in automodule_options %} + :{{ option }}: +{%- endfor %} + diff --git a/docs/_templates/apidoc/package.rst_t b/docs/_templates/apidoc/package.rst_t new file mode 100644 index 0000000000000000000000000000000000000000..83742b3f7c66c10e0ebbe78718dea91e34d050a5 --- /dev/null +++ b/docs/_templates/apidoc/package.rst_t @@ -0,0 +1,52 @@ +{%- macro automodule(modname, options) -%} +.. automodule:: {{ modname }} +{%- for option in options %} + :{{ option }}: +{%- endfor %} +{%- endmacro %} + +{%- macro toctree(docnames) -%} +.. toctree:: + :maxdepth: {{ maxdepth }} +{% for docname in docnames %} + {{ docname }} +{%- endfor %} +{%- endmacro %} + +{%- if is_namespace %} +{{- pkgname | e | heading }} +{% else %} +{{- pkgname | e | heading }} +{% endif %} + +{%- if is_namespace %} +.. py:module:: {{ pkgname }} +{% endif %} + +{%- if modulefirst and not is_namespace %} +{{ automodule(pkgname, automodule_options) }} +{% endif %} + +{%- if subpackages %} +{{ toctree(subpackages) }} +{% endif %} + +{%- if submodules %} +{% if separatemodules %} +{{ toctree(submodules) }} +{% else %} +{%- for submodule in submodules %} +{% if show_headings %} +{{- submodule | e | heading(2) }} +{% endif %} +{{ automodule(submodule, automodule_options) }} +{% endfor %} +{%- endif %} +{%- endif %} + +{%- if not modulefirst and not is_namespace %} +Module contents +--------------- + +{{ automodule(pkgname, automodule_options) }} +{% endif %} diff --git a/docs/_templates/apidoc/toc.rst_t b/docs/_templates/apidoc/toc.rst_t new file mode 100644 index 0000000000000000000000000000000000000000..f0877eeb2f85324a48eb63d793a536a8cfdb4a00 --- /dev/null +++ b/docs/_templates/apidoc/toc.rst_t @@ -0,0 +1,8 @@ +{{ header | heading }} + +.. toctree:: + :maxdepth: {{ maxdepth }} +{% for docname in docnames %} + {{ docname }} +{%- endfor %} + diff --git a/docs/colossalai/colossalai.amp.amp_type.rst b/docs/colossalai/colossalai.amp.amp_type.rst new file mode 100644 index 0000000000000000000000000000000000000000..067af7d8c51a88ca94140b5b79fbbce7beccf41f --- /dev/null +++ b/docs/colossalai/colossalai.amp.amp_type.rst @@ -0,0 +1,5 @@ +colossalai.amp.amp\_type +======================== + +.. automodule:: colossalai.amp.amp_type + :members: diff --git a/docs/colossalai/colossalai.amp.apex_amp.apex_amp.rst b/docs/colossalai/colossalai.amp.apex_amp.apex_amp.rst new file mode 100644 index 0000000000000000000000000000000000000000..cba7e00625a4d6d018e1416cbdc984e659a4f345 --- /dev/null +++ b/docs/colossalai/colossalai.amp.apex_amp.apex_amp.rst @@ -0,0 +1,5 @@ +colossalai.amp.apex\_amp.apex\_amp +================================== + +.. automodule:: colossalai.amp.apex_amp.apex_amp + :members: diff --git a/docs/colossalai/colossalai.amp.apex_amp.rst b/docs/colossalai/colossalai.amp.apex_amp.rst new file mode 100644 index 0000000000000000000000000000000000000000..7116a538b4c1d227354d1a16a64ce00165427cc3 --- /dev/null +++ b/docs/colossalai/colossalai.amp.apex_amp.rst @@ -0,0 +1,11 @@ +colossalai.amp.apex\_amp +======================== + +.. automodule:: colossalai.amp.apex_amp + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.amp.apex_amp.apex_amp diff --git a/docs/colossalai/colossalai.amp.naive_amp.grad_scaler.rst b/docs/colossalai/colossalai.amp.naive_amp.grad_scaler.rst new file mode 100644 index 0000000000000000000000000000000000000000..12d477825659803f22dbcf42f16f0975c9a679aa --- /dev/null +++ b/docs/colossalai/colossalai.amp.naive_amp.grad_scaler.rst @@ -0,0 +1,8 @@ +colossalai.amp.naive\_amp.grad\_scaler +====================================== + +.. automodule:: colossalai.amp.naive_amp.grad_scaler + :members: + + + diff --git a/docs/colossalai/colossalai.amp.naive_amp.naive_amp.rst b/docs/colossalai/colossalai.amp.naive_amp.naive_amp.rst new file mode 100644 index 0000000000000000000000000000000000000000..e20f22b2e386effc3f68c5ef49c490dbac75aaea --- /dev/null +++ b/docs/colossalai/colossalai.amp.naive_amp.naive_amp.rst @@ -0,0 +1,5 @@ +colossalai.amp.naive\_amp.naive\_amp +==================================== + +.. automodule:: colossalai.amp.naive_amp.naive_amp + :members: diff --git a/docs/colossalai/colossalai.amp.naive_amp.rst b/docs/colossalai/colossalai.amp.naive_amp.rst new file mode 100644 index 0000000000000000000000000000000000000000..fd364c05331ce2f5f43308438821f892d68cc6de --- /dev/null +++ b/docs/colossalai/colossalai.amp.naive_amp.rst @@ -0,0 +1,16 @@ +colossalai.amp.naive\_amp +========================= + +.. automodule:: colossalai.amp.naive_amp + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.amp.naive_amp.grad_scaler + + +.. toctree:: + :maxdepth: 2 + + colossalai.amp.naive_amp.naive_amp diff --git a/docs/colossalai/colossalai.amp.rst b/docs/colossalai/colossalai.amp.rst new file mode 100644 index 0000000000000000000000000000000000000000..5ef4f36c13ac30b6accda7435a9e6d5b30f49a4e --- /dev/null +++ b/docs/colossalai/colossalai.amp.rst @@ -0,0 +1,18 @@ +colossalai.amp +============== + +.. automodule:: colossalai.amp + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.amp.apex_amp + colossalai.amp.naive_amp + colossalai.amp.torch_amp + + +.. toctree:: + :maxdepth: 2 + + colossalai.amp.amp_type diff --git a/docs/colossalai/colossalai.amp.torch_amp.rst b/docs/colossalai/colossalai.amp.torch_amp.rst new file mode 100644 index 0000000000000000000000000000000000000000..f10095f136e091bad583e643151a1c6eae56351a --- /dev/null +++ b/docs/colossalai/colossalai.amp.torch_amp.rst @@ -0,0 +1,11 @@ +colossalai.amp.torch\_amp +========================= + +.. automodule:: colossalai.amp.torch_amp + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.amp.torch_amp.torch_amp diff --git a/docs/colossalai/colossalai.amp.torch_amp.torch_amp.rst b/docs/colossalai/colossalai.amp.torch_amp.torch_amp.rst new file mode 100644 index 0000000000000000000000000000000000000000..5f1549cb8d48aac1c0b51d03c9bd05aac0f16f46 --- /dev/null +++ b/docs/colossalai/colossalai.amp.torch_amp.torch_amp.rst @@ -0,0 +1,5 @@ +colossalai.amp.torch\_amp.torch\_amp +==================================== + +.. automodule:: colossalai.amp.torch_amp.torch_amp + :members: diff --git a/docs/colossalai/colossalai.builder.builder.rst b/docs/colossalai/colossalai.builder.builder.rst new file mode 100644 index 0000000000000000000000000000000000000000..85da78ab9e3de33e5eb4e7fcc9a659a7d3fa5952 --- /dev/null +++ b/docs/colossalai/colossalai.builder.builder.rst @@ -0,0 +1,5 @@ +colossalai.builder.builder +========================== + +.. automodule:: colossalai.builder.builder + :members: diff --git a/docs/colossalai/colossalai.builder.rst b/docs/colossalai/colossalai.builder.rst new file mode 100644 index 0000000000000000000000000000000000000000..61163d7c1ea1021af0f17b0362871656ab14d617 --- /dev/null +++ b/docs/colossalai/colossalai.builder.rst @@ -0,0 +1,11 @@ +colossalai.builder +================== + +.. automodule:: colossalai.builder + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.builder.builder diff --git a/docs/colossalai/colossalai.cli.benchmark.benchmark.rst b/docs/colossalai/colossalai.cli.benchmark.benchmark.rst new file mode 100644 index 0000000000000000000000000000000000000000..94a4170c85900e310b99572981b8c0f70c096bb5 --- /dev/null +++ b/docs/colossalai/colossalai.cli.benchmark.benchmark.rst @@ -0,0 +1,5 @@ +colossalai.cli.benchmark.benchmark +================================== + +.. automodule:: colossalai.cli.benchmark.benchmark + :members: diff --git a/docs/colossalai/colossalai.cli.benchmark.models.rst b/docs/colossalai/colossalai.cli.benchmark.models.rst new file mode 100644 index 0000000000000000000000000000000000000000..4e6290288d59e4e2d503cfbd8ddb7f5a22b3239e --- /dev/null +++ b/docs/colossalai/colossalai.cli.benchmark.models.rst @@ -0,0 +1,5 @@ +colossalai.cli.benchmark.models +=============================== + +.. automodule:: colossalai.cli.benchmark.models + :members: diff --git a/docs/colossalai/colossalai.cli.benchmark.rst b/docs/colossalai/colossalai.cli.benchmark.rst new file mode 100644 index 0000000000000000000000000000000000000000..80fb43dde04bc3a824a295415b3942ea313f40ea --- /dev/null +++ b/docs/colossalai/colossalai.cli.benchmark.rst @@ -0,0 +1,13 @@ +colossalai.cli.benchmark +======================== + +.. automodule:: colossalai.cli.benchmark + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.cli.benchmark.benchmark + colossalai.cli.benchmark.models + colossalai.cli.benchmark.utils diff --git a/docs/colossalai/colossalai.cli.benchmark.utils.rst b/docs/colossalai/colossalai.cli.benchmark.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..12fbaf2270ec3d06913035c694cd999a20ee5945 --- /dev/null +++ b/docs/colossalai/colossalai.cli.benchmark.utils.rst @@ -0,0 +1,5 @@ +colossalai.cli.benchmark.utils +============================== + +.. automodule:: colossalai.cli.benchmark.utils + :members: diff --git a/docs/colossalai/colossalai.cli.check.check_installation.rst b/docs/colossalai/colossalai.cli.check.check_installation.rst new file mode 100644 index 0000000000000000000000000000000000000000..95b2d02ca371514cb4fcc574291fbdc4fdb79821 --- /dev/null +++ b/docs/colossalai/colossalai.cli.check.check_installation.rst @@ -0,0 +1,5 @@ +colossalai.cli.check.check\_installation +======================================== + +.. automodule:: colossalai.cli.check.check_installation + :members: diff --git a/docs/colossalai/colossalai.cli.check.rst b/docs/colossalai/colossalai.cli.check.rst new file mode 100644 index 0000000000000000000000000000000000000000..262ae7ad31ba1ad664591986662181ca61ab761f --- /dev/null +++ b/docs/colossalai/colossalai.cli.check.rst @@ -0,0 +1,11 @@ +colossalai.cli.check +==================== + +.. automodule:: colossalai.cli.check + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.cli.check.check_installation diff --git a/docs/colossalai/colossalai.cli.cli.rst b/docs/colossalai/colossalai.cli.cli.rst new file mode 100644 index 0000000000000000000000000000000000000000..8f83973d5e0c983282b3d33cf27fbb71e2f90339 --- /dev/null +++ b/docs/colossalai/colossalai.cli.cli.rst @@ -0,0 +1,5 @@ +colossalai.cli.cli +================== + +.. automodule:: colossalai.cli.cli + :members: diff --git a/docs/colossalai/colossalai.cli.launcher.hostinfo.rst b/docs/colossalai/colossalai.cli.launcher.hostinfo.rst new file mode 100644 index 0000000000000000000000000000000000000000..5bcd9dd8cc4c814278e1a118428b3edd2bd60acf --- /dev/null +++ b/docs/colossalai/colossalai.cli.launcher.hostinfo.rst @@ -0,0 +1,5 @@ +colossalai.cli.launcher.hostinfo +================================ + +.. automodule:: colossalai.cli.launcher.hostinfo + :members: diff --git a/docs/colossalai/colossalai.cli.launcher.multinode_runner.rst b/docs/colossalai/colossalai.cli.launcher.multinode_runner.rst new file mode 100644 index 0000000000000000000000000000000000000000..223b0deac1f108c90d5386ced8c1d43d574b9f1c --- /dev/null +++ b/docs/colossalai/colossalai.cli.launcher.multinode_runner.rst @@ -0,0 +1,5 @@ +colossalai.cli.launcher.multinode\_runner +========================================= + +.. automodule:: colossalai.cli.launcher.multinode_runner + :members: diff --git a/docs/colossalai/colossalai.cli.launcher.rst b/docs/colossalai/colossalai.cli.launcher.rst new file mode 100644 index 0000000000000000000000000000000000000000..38bef61c790ddee207b856cc5f07da4d82582d8c --- /dev/null +++ b/docs/colossalai/colossalai.cli.launcher.rst @@ -0,0 +1,13 @@ +colossalai.cli.launcher +======================= + +.. automodule:: colossalai.cli.launcher + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.cli.launcher.hostinfo + colossalai.cli.launcher.multinode_runner + colossalai.cli.launcher.run diff --git a/docs/colossalai/colossalai.cli.launcher.run.rst b/docs/colossalai/colossalai.cli.launcher.run.rst new file mode 100644 index 0000000000000000000000000000000000000000..8506fb9e31657dd72017bc89f9c912e4ab8f503a --- /dev/null +++ b/docs/colossalai/colossalai.cli.launcher.run.rst @@ -0,0 +1,5 @@ +colossalai.cli.launcher.run +=========================== + +.. automodule:: colossalai.cli.launcher.run + :members: diff --git a/docs/colossalai/colossalai.cli.rst b/docs/colossalai/colossalai.cli.rst new file mode 100644 index 0000000000000000000000000000000000000000..8cc0dcb04aedc2ce713fa86b5cec078fe8648a00 --- /dev/null +++ b/docs/colossalai/colossalai.cli.rst @@ -0,0 +1,18 @@ +colossalai.cli +============== + +.. automodule:: colossalai.cli + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.cli.benchmark + colossalai.cli.check + colossalai.cli.launcher + + +.. toctree:: + :maxdepth: 2 + + colossalai.cli.cli diff --git a/docs/colossalai/colossalai.communication.collective.rst b/docs/colossalai/colossalai.communication.collective.rst new file mode 100644 index 0000000000000000000000000000000000000000..5015edf98901fa4077b6b11e4ab81a9979ae84c4 --- /dev/null +++ b/docs/colossalai/colossalai.communication.collective.rst @@ -0,0 +1,5 @@ +colossalai.communication.collective +=================================== + +.. automodule:: colossalai.communication.collective + :members: diff --git a/docs/colossalai/colossalai.communication.p2p.rst b/docs/colossalai/colossalai.communication.p2p.rst new file mode 100644 index 0000000000000000000000000000000000000000..79135bb8630f6dfa57f2f2857e4efaac046e0b5c --- /dev/null +++ b/docs/colossalai/colossalai.communication.p2p.rst @@ -0,0 +1,5 @@ +colossalai.communication.p2p +============================ + +.. automodule:: colossalai.communication.p2p + :members: diff --git a/docs/colossalai/colossalai.communication.ring.rst b/docs/colossalai/colossalai.communication.ring.rst new file mode 100644 index 0000000000000000000000000000000000000000..c218d4bed350f7af9e81cf4bfccb3bf94e273d94 --- /dev/null +++ b/docs/colossalai/colossalai.communication.ring.rst @@ -0,0 +1,5 @@ +colossalai.communication.ring +============================= + +.. automodule:: colossalai.communication.ring + :members: diff --git a/docs/colossalai/colossalai.communication.rst b/docs/colossalai/colossalai.communication.rst new file mode 100644 index 0000000000000000000000000000000000000000..5086fa663ec7e09ee12eb8393454dac783453354 --- /dev/null +++ b/docs/colossalai/colossalai.communication.rst @@ -0,0 +1,14 @@ +colossalai.communication +======================== + +.. automodule:: colossalai.communication + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.communication.collective + colossalai.communication.p2p + colossalai.communication.ring + colossalai.communication.utils diff --git a/docs/colossalai/colossalai.communication.utils.rst b/docs/colossalai/colossalai.communication.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..19a36cc9ff6f75448dc31bbdcbe41f1755bbbe83 --- /dev/null +++ b/docs/colossalai/colossalai.communication.utils.rst @@ -0,0 +1,5 @@ +colossalai.communication.utils +============================== + +.. automodule:: colossalai.communication.utils + :members: diff --git a/docs/colossalai/colossalai.constants.rst b/docs/colossalai/colossalai.constants.rst new file mode 100644 index 0000000000000000000000000000000000000000..330b3e8668ec88bc1610e0ff61b058e500c811f5 --- /dev/null +++ b/docs/colossalai/colossalai.constants.rst @@ -0,0 +1,5 @@ +colossalai.constants +==================== + +.. automodule:: colossalai.constants + :members: diff --git a/docs/colossalai/colossalai.context.config.rst b/docs/colossalai/colossalai.context.config.rst new file mode 100644 index 0000000000000000000000000000000000000000..2fb1b99d3e7af8af7cafb5f1ea7dd744aa888fc4 --- /dev/null +++ b/docs/colossalai/colossalai.context.config.rst @@ -0,0 +1,5 @@ +colossalai.context.config +========================= + +.. automodule:: colossalai.context.config + :members: diff --git a/docs/colossalai/colossalai.context.moe_context.rst b/docs/colossalai/colossalai.context.moe_context.rst new file mode 100644 index 0000000000000000000000000000000000000000..9027d19ff02328c3737e3428941064859117470f --- /dev/null +++ b/docs/colossalai/colossalai.context.moe_context.rst @@ -0,0 +1,5 @@ +colossalai.context.moe\_context +=============================== + +.. automodule:: colossalai.context.moe_context + :members: diff --git a/docs/colossalai/colossalai.context.parallel_context.rst b/docs/colossalai/colossalai.context.parallel_context.rst new file mode 100644 index 0000000000000000000000000000000000000000..d1c82c5178451e954115425e0c52620250371ccb --- /dev/null +++ b/docs/colossalai/colossalai.context.parallel_context.rst @@ -0,0 +1,5 @@ +colossalai.context.parallel\_context +==================================== + +.. automodule:: colossalai.context.parallel_context + :members: diff --git a/docs/colossalai/colossalai.context.parallel_mode.rst b/docs/colossalai/colossalai.context.parallel_mode.rst new file mode 100644 index 0000000000000000000000000000000000000000..f7ac137493fb4ad0f476c9ec82af719368bc1124 --- /dev/null +++ b/docs/colossalai/colossalai.context.parallel_mode.rst @@ -0,0 +1,5 @@ +colossalai.context.parallel\_mode +================================= + +.. automodule:: colossalai.context.parallel_mode + :members: diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_1d.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_1d.rst new file mode 100644 index 0000000000000000000000000000000000000000..88cbf3ebadb3845028d3cc004981e47443a657fb --- /dev/null +++ b/docs/colossalai/colossalai.context.process_group_initializer.initializer_1d.rst @@ -0,0 +1,5 @@ +colossalai.context.process\_group\_initializer.initializer\_1d +============================================================== + +.. automodule:: colossalai.context.process_group_initializer.initializer_1d + :members: diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_2d.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_2d.rst new file mode 100644 index 0000000000000000000000000000000000000000..d99a2e1c31775187ce5db8239a04c749e750acb8 --- /dev/null +++ b/docs/colossalai/colossalai.context.process_group_initializer.initializer_2d.rst @@ -0,0 +1,5 @@ +colossalai.context.process\_group\_initializer.initializer\_2d +============================================================== + +.. automodule:: colossalai.context.process_group_initializer.initializer_2d + :members: diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_2p5d.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_2p5d.rst new file mode 100644 index 0000000000000000000000000000000000000000..73d80e4431bbbefb094459cb53ff866239bc49b0 --- /dev/null +++ b/docs/colossalai/colossalai.context.process_group_initializer.initializer_2p5d.rst @@ -0,0 +1,5 @@ +colossalai.context.process\_group\_initializer.initializer\_2p5d +================================================================ + +.. automodule:: colossalai.context.process_group_initializer.initializer_2p5d + :members: diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_3d.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_3d.rst new file mode 100644 index 0000000000000000000000000000000000000000..5cfba5ce0870e973930bcb5ea925185561b8509b --- /dev/null +++ b/docs/colossalai/colossalai.context.process_group_initializer.initializer_3d.rst @@ -0,0 +1,5 @@ +colossalai.context.process\_group\_initializer.initializer\_3d +============================================================== + +.. automodule:: colossalai.context.process_group_initializer.initializer_3d + :members: diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_data.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_data.rst new file mode 100644 index 0000000000000000000000000000000000000000..55ad05f32b143b768d4d6b46add4e513f07a57fa --- /dev/null +++ b/docs/colossalai/colossalai.context.process_group_initializer.initializer_data.rst @@ -0,0 +1,5 @@ +colossalai.context.process\_group\_initializer.initializer\_data +================================================================ + +.. automodule:: colossalai.context.process_group_initializer.initializer_data + :members: diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_model.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_model.rst new file mode 100644 index 0000000000000000000000000000000000000000..8f2d79369915a0c0a5b76e17a145be9e98311ab5 --- /dev/null +++ b/docs/colossalai/colossalai.context.process_group_initializer.initializer_model.rst @@ -0,0 +1,5 @@ +colossalai.context.process\_group\_initializer.initializer\_model +================================================================= + +.. automodule:: colossalai.context.process_group_initializer.initializer_model + :members: diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_pipeline.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_pipeline.rst new file mode 100644 index 0000000000000000000000000000000000000000..466d5143a02b58c86b8cb3adbf6461f2d59a759f --- /dev/null +++ b/docs/colossalai/colossalai.context.process_group_initializer.initializer_pipeline.rst @@ -0,0 +1,5 @@ +colossalai.context.process\_group\_initializer.initializer\_pipeline +==================================================================== + +.. automodule:: colossalai.context.process_group_initializer.initializer_pipeline + :members: diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_sequence.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_sequence.rst new file mode 100644 index 0000000000000000000000000000000000000000..dab71cc3c3917c416e1d35ca6c5a6dafe4fdc1b9 --- /dev/null +++ b/docs/colossalai/colossalai.context.process_group_initializer.initializer_sequence.rst @@ -0,0 +1,5 @@ +colossalai.context.process\_group\_initializer.initializer\_sequence +==================================================================== + +.. automodule:: colossalai.context.process_group_initializer.initializer_sequence + :members: diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_tensor.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_tensor.rst new file mode 100644 index 0000000000000000000000000000000000000000..0c2d8d1e9daaa9a7392c048ffa7e7d3bf9e59342 --- /dev/null +++ b/docs/colossalai/colossalai.context.process_group_initializer.initializer_tensor.rst @@ -0,0 +1,5 @@ +colossalai.context.process\_group\_initializer.initializer\_tensor +================================================================== + +.. automodule:: colossalai.context.process_group_initializer.initializer_tensor + :members: diff --git a/docs/colossalai/colossalai.context.process_group_initializer.process_group_initializer.rst b/docs/colossalai/colossalai.context.process_group_initializer.process_group_initializer.rst new file mode 100644 index 0000000000000000000000000000000000000000..3f98723c170b56c1b8b12dce96edb611cee1dc66 --- /dev/null +++ b/docs/colossalai/colossalai.context.process_group_initializer.process_group_initializer.rst @@ -0,0 +1,5 @@ +colossalai.context.process\_group\_initializer.process\_group\_initializer +========================================================================== + +.. automodule:: colossalai.context.process_group_initializer.process_group_initializer + :members: diff --git a/docs/colossalai/colossalai.context.process_group_initializer.rst b/docs/colossalai/colossalai.context.process_group_initializer.rst new file mode 100644 index 0000000000000000000000000000000000000000..519337e9c71d2a19cbcc06af6e89868b897f5b53 --- /dev/null +++ b/docs/colossalai/colossalai.context.process_group_initializer.rst @@ -0,0 +1,20 @@ +colossalai.context.process\_group\_initializer +============================================== + +.. automodule:: colossalai.context.process_group_initializer + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.context.process_group_initializer.initializer_1d + colossalai.context.process_group_initializer.initializer_2d + colossalai.context.process_group_initializer.initializer_2p5d + colossalai.context.process_group_initializer.initializer_3d + colossalai.context.process_group_initializer.initializer_data + colossalai.context.process_group_initializer.initializer_model + colossalai.context.process_group_initializer.initializer_pipeline + colossalai.context.process_group_initializer.initializer_sequence + colossalai.context.process_group_initializer.initializer_tensor + colossalai.context.process_group_initializer.process_group_initializer diff --git a/docs/colossalai/colossalai.context.random.rst b/docs/colossalai/colossalai.context.random.rst new file mode 100644 index 0000000000000000000000000000000000000000..8d4b9c56af3cbeb64d14e7891a282a6f75ce7fa9 --- /dev/null +++ b/docs/colossalai/colossalai.context.random.rst @@ -0,0 +1,11 @@ +colossalai.context.random +========================= + +.. automodule:: colossalai.context.random + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.context.random.seed_manager diff --git a/docs/colossalai/colossalai.context.random.seed_manager.rst b/docs/colossalai/colossalai.context.random.seed_manager.rst new file mode 100644 index 0000000000000000000000000000000000000000..b71f35c2750c973eb9664bd96d3210fb0722c005 --- /dev/null +++ b/docs/colossalai/colossalai.context.random.seed_manager.rst @@ -0,0 +1,5 @@ +colossalai.context.random.seed\_manager +======================================= + +.. automodule:: colossalai.context.random.seed_manager + :members: diff --git a/docs/colossalai/colossalai.context.rst b/docs/colossalai/colossalai.context.rst new file mode 100644 index 0000000000000000000000000000000000000000..102a9e02eaa43d2ecd44e49aba9bc37c977a00dc --- /dev/null +++ b/docs/colossalai/colossalai.context.rst @@ -0,0 +1,21 @@ +colossalai.context +================== + +.. automodule:: colossalai.context + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.context.process_group_initializer + colossalai.context.random + + +.. toctree:: + :maxdepth: 2 + + colossalai.context.config + colossalai.context.moe_context + colossalai.context.parallel_context + colossalai.context.parallel_mode + colossalai.context.singleton_meta diff --git a/docs/colossalai/colossalai.context.singleton_meta.rst b/docs/colossalai/colossalai.context.singleton_meta.rst new file mode 100644 index 0000000000000000000000000000000000000000..ae4ceb314f32f732cd7ed49c61ba6fcd66107053 --- /dev/null +++ b/docs/colossalai/colossalai.context.singleton_meta.rst @@ -0,0 +1,5 @@ +colossalai.context.singleton\_meta +================================== + +.. automodule:: colossalai.context.singleton_meta + :members: diff --git a/docs/colossalai/colossalai.core.rst b/docs/colossalai/colossalai.core.rst new file mode 100644 index 0000000000000000000000000000000000000000..d9ddb76ed72a77ab98e70f2c0114baa07c16deef --- /dev/null +++ b/docs/colossalai/colossalai.core.rst @@ -0,0 +1,5 @@ +colossalai.core +=============== + +.. automodule:: colossalai.core + :members: diff --git a/docs/colossalai/colossalai.engine.gradient_accumulation.rst b/docs/colossalai/colossalai.engine.gradient_accumulation.rst new file mode 100644 index 0000000000000000000000000000000000000000..75fc0e9a24eb3fd04a48284a4cc5fcf740f66d64 --- /dev/null +++ b/docs/colossalai/colossalai.engine.gradient_accumulation.rst @@ -0,0 +1,5 @@ +colossalai.engine.gradient\_accumulation +======================================== + +.. automodule:: colossalai.engine.gradient_accumulation + :members: diff --git a/docs/colossalai/colossalai.engine.gradient_handler.rst b/docs/colossalai/colossalai.engine.gradient_handler.rst new file mode 100644 index 0000000000000000000000000000000000000000..27eb2b56a29f3b956a9557cc06f10ad9e4136144 --- /dev/null +++ b/docs/colossalai/colossalai.engine.gradient_handler.rst @@ -0,0 +1,11 @@ +colossalai.engine.gradient\_handler +=================================== + +.. automodule:: colossalai.engine.gradient_handler + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.engine.gradient_handler.utils diff --git a/docs/colossalai/colossalai.engine.gradient_handler.utils.rst b/docs/colossalai/colossalai.engine.gradient_handler.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..c8997e135b609d24dcf16ac80361767e6a794679 --- /dev/null +++ b/docs/colossalai/colossalai.engine.gradient_handler.utils.rst @@ -0,0 +1,5 @@ +colossalai.engine.gradient\_handler.utils +========================================= + +.. automodule:: colossalai.engine.gradient_handler.utils + :members: diff --git a/docs/colossalai/colossalai.engine.rst b/docs/colossalai/colossalai.engine.rst new file mode 100644 index 0000000000000000000000000000000000000000..3d194b70695ee7b8d99fde8879c0a0f8a5247557 --- /dev/null +++ b/docs/colossalai/colossalai.engine.rst @@ -0,0 +1,12 @@ +colossalai.engine +================= + +.. automodule:: colossalai.engine + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.engine.gradient_accumulation + colossalai.engine.gradient_handler + colossalai.engine.schedule diff --git a/docs/colossalai/colossalai.engine.schedule.rst b/docs/colossalai/colossalai.engine.schedule.rst new file mode 100644 index 0000000000000000000000000000000000000000..2909373f00020afea52509e8c92d3563f6a128ce --- /dev/null +++ b/docs/colossalai/colossalai.engine.schedule.rst @@ -0,0 +1,5 @@ +colossalai.engine.schedule +========================== + +.. automodule:: colossalai.engine.schedule + :members: diff --git a/docs/colossalai/colossalai.fx.passes.adding_split_node_pass.rst b/docs/colossalai/colossalai.fx.passes.adding_split_node_pass.rst new file mode 100644 index 0000000000000000000000000000000000000000..6799fdc658cdc32a2a2ff2ed48f1182918977269 --- /dev/null +++ b/docs/colossalai/colossalai.fx.passes.adding_split_node_pass.rst @@ -0,0 +1,5 @@ +colossalai.fx.passes.adding\_split\_node\_pass +============================================== + +.. automodule:: colossalai.fx.passes.adding_split_node_pass + :members: diff --git a/docs/colossalai/colossalai.fx.passes.meta_info_prop.rst b/docs/colossalai/colossalai.fx.passes.meta_info_prop.rst new file mode 100644 index 0000000000000000000000000000000000000000..4e51732ce83d961d6e14bcf2821a34e99493c94a --- /dev/null +++ b/docs/colossalai/colossalai.fx.passes.meta_info_prop.rst @@ -0,0 +1,5 @@ +colossalai.fx.passes.meta\_info\_prop +===================================== + +.. automodule:: colossalai.fx.passes.meta_info_prop + :members: diff --git a/docs/colossalai/colossalai.fx.passes.rst b/docs/colossalai/colossalai.fx.passes.rst new file mode 100644 index 0000000000000000000000000000000000000000..fac10b7680345a5f1009495a59c09d6709c7ed93 --- /dev/null +++ b/docs/colossalai/colossalai.fx.passes.rst @@ -0,0 +1,15 @@ +colossalai.fx.passes +==================== + +.. automodule:: colossalai.fx.passes + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.fx.passes.adding_split_node_pass + colossalai.fx.passes.meta_info_prop + colossalai.fx.passes.shard_1d_pass + colossalai.fx.passes.split_module + colossalai.fx.passes.utils diff --git a/docs/colossalai/colossalai.fx.passes.shard_1d_pass.rst b/docs/colossalai/colossalai.fx.passes.shard_1d_pass.rst new file mode 100644 index 0000000000000000000000000000000000000000..0942e96d46dc00465034a9e091eb9d0f749d6a5e --- /dev/null +++ b/docs/colossalai/colossalai.fx.passes.shard_1d_pass.rst @@ -0,0 +1,5 @@ +colossalai.fx.passes.shard\_1d\_pass +==================================== + +.. automodule:: colossalai.fx.passes.shard_1d_pass + :members: diff --git a/docs/colossalai/colossalai.fx.passes.split_module.rst b/docs/colossalai/colossalai.fx.passes.split_module.rst new file mode 100644 index 0000000000000000000000000000000000000000..9e5e582592548ac8ed30137f5f1f34621bcbf1c4 --- /dev/null +++ b/docs/colossalai/colossalai.fx.passes.split_module.rst @@ -0,0 +1,5 @@ +colossalai.fx.passes.split\_module +================================== + +.. automodule:: colossalai.fx.passes.split_module + :members: diff --git a/docs/colossalai/colossalai.fx.passes.utils.rst b/docs/colossalai/colossalai.fx.passes.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..4afd9256322b10a17731825622e5a16e0d24eafe --- /dev/null +++ b/docs/colossalai/colossalai.fx.passes.utils.rst @@ -0,0 +1,5 @@ +colossalai.fx.passes.utils +========================== + +.. automodule:: colossalai.fx.passes.utils + :members: diff --git a/docs/colossalai/colossalai.fx.proxy.rst b/docs/colossalai/colossalai.fx.proxy.rst new file mode 100644 index 0000000000000000000000000000000000000000..4b92da41c794893edf4a37db32b2c3d618c82f34 --- /dev/null +++ b/docs/colossalai/colossalai.fx.proxy.rst @@ -0,0 +1,5 @@ +colossalai.fx.proxy +=================== + +.. automodule:: colossalai.fx.proxy + :members: diff --git a/docs/colossalai/colossalai.fx.rst b/docs/colossalai/colossalai.fx.rst new file mode 100644 index 0000000000000000000000000000000000000000..778d642c3a11b1e5a09b301130dd29968ca44ef3 --- /dev/null +++ b/docs/colossalai/colossalai.fx.rst @@ -0,0 +1,17 @@ +colossalai.fx +============= + +.. automodule:: colossalai.fx + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.fx.passes + colossalai.fx.tracer + + +.. toctree:: + :maxdepth: 2 + + colossalai.fx.proxy diff --git a/docs/colossalai/colossalai.fx.tracer.rst b/docs/colossalai/colossalai.fx.tracer.rst new file mode 100644 index 0000000000000000000000000000000000000000..d2f743d67d55f8f7f7e7074eb6d39d61cb1ea167 --- /dev/null +++ b/docs/colossalai/colossalai.fx.tracer.rst @@ -0,0 +1,11 @@ +colossalai.fx.tracer +==================== + +.. automodule:: colossalai.fx.tracer + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.fx.tracer.tracer diff --git a/docs/colossalai/colossalai.fx.tracer.tracer.rst b/docs/colossalai/colossalai.fx.tracer.tracer.rst new file mode 100644 index 0000000000000000000000000000000000000000..83b98bafd8251e4e8bcdaba2c0e77c0ddabf3c30 --- /dev/null +++ b/docs/colossalai/colossalai.fx.tracer.tracer.rst @@ -0,0 +1,5 @@ +colossalai.fx.tracer.tracer +=========================== + +.. automodule:: colossalai.fx.tracer.tracer + :members: diff --git a/docs/colossalai/colossalai.gemini.chunk.rst b/docs/colossalai/colossalai.gemini.chunk.rst new file mode 100644 index 0000000000000000000000000000000000000000..9fe1c2b415d6b73d96702796697f43834d95b7c6 --- /dev/null +++ b/docs/colossalai/colossalai.gemini.chunk.rst @@ -0,0 +1,5 @@ +colossalai.gemini.chunk +======================= + +.. automodule:: colossalai.gemini.chunk + :members: diff --git a/docs/colossalai/colossalai.gemini.chunk_mgr.rst b/docs/colossalai/colossalai.gemini.chunk_mgr.rst new file mode 100644 index 0000000000000000000000000000000000000000..acb554faf31942a0cd89fcfe5368250cdd49bc55 --- /dev/null +++ b/docs/colossalai/colossalai.gemini.chunk_mgr.rst @@ -0,0 +1,5 @@ +colossalai.gemini.chunk\_mgr +============================ + +.. automodule:: colossalai.gemini.chunk_mgr + :members: diff --git a/docs/colossalai/colossalai.gemini.gemini_context.rst b/docs/colossalai/colossalai.gemini.gemini_context.rst new file mode 100644 index 0000000000000000000000000000000000000000..be48840622533ccf3a22835f39c8cc87be1bbe5e --- /dev/null +++ b/docs/colossalai/colossalai.gemini.gemini_context.rst @@ -0,0 +1,5 @@ +colossalai.gemini.gemini\_context +================================= + +.. automodule:: colossalai.gemini.gemini_context + :members: diff --git a/docs/colossalai/colossalai.gemini.gemini_mgr.rst b/docs/colossalai/colossalai.gemini.gemini_mgr.rst new file mode 100644 index 0000000000000000000000000000000000000000..5d7f944f7a5657df0f2e761608313633aa796c9b --- /dev/null +++ b/docs/colossalai/colossalai.gemini.gemini_mgr.rst @@ -0,0 +1,5 @@ +colossalai.gemini.gemini\_mgr +============================= + +.. automodule:: colossalai.gemini.gemini_mgr + :members: diff --git a/docs/colossalai/colossalai.gemini.memory_tracer.memory_monitor.rst b/docs/colossalai/colossalai.gemini.memory_tracer.memory_monitor.rst new file mode 100644 index 0000000000000000000000000000000000000000..e8088a609f34ba22b0e526abe70bef223151b71a --- /dev/null +++ b/docs/colossalai/colossalai.gemini.memory_tracer.memory_monitor.rst @@ -0,0 +1,5 @@ +colossalai.gemini.memory\_tracer.memory\_monitor +================================================ + +.. automodule:: colossalai.gemini.memory_tracer.memory_monitor + :members: diff --git a/docs/colossalai/colossalai.gemini.memory_tracer.memstats_collector.rst b/docs/colossalai/colossalai.gemini.memory_tracer.memstats_collector.rst new file mode 100644 index 0000000000000000000000000000000000000000..e2682220c27b54c388acdfee3963fbecaff8e3c4 --- /dev/null +++ b/docs/colossalai/colossalai.gemini.memory_tracer.memstats_collector.rst @@ -0,0 +1,5 @@ +colossalai.gemini.memory\_tracer.memstats\_collector +==================================================== + +.. automodule:: colossalai.gemini.memory_tracer.memstats_collector + :members: diff --git a/docs/colossalai/colossalai.gemini.memory_tracer.model_data_memtracer.rst b/docs/colossalai/colossalai.gemini.memory_tracer.model_data_memtracer.rst new file mode 100644 index 0000000000000000000000000000000000000000..ccdfe6682c3fd973cd1206e112b3ea6d35bc7258 --- /dev/null +++ b/docs/colossalai/colossalai.gemini.memory_tracer.model_data_memtracer.rst @@ -0,0 +1,5 @@ +colossalai.gemini.memory\_tracer.model\_data\_memtracer +======================================================= + +.. automodule:: colossalai.gemini.memory_tracer.model_data_memtracer + :members: diff --git a/docs/colossalai/colossalai.gemini.memory_tracer.rst b/docs/colossalai/colossalai.gemini.memory_tracer.rst new file mode 100644 index 0000000000000000000000000000000000000000..f3d9c4d76dd8ba00892829b45674a3be3cd09aa2 --- /dev/null +++ b/docs/colossalai/colossalai.gemini.memory_tracer.rst @@ -0,0 +1,13 @@ +colossalai.gemini.memory\_tracer +================================ + +.. automodule:: colossalai.gemini.memory_tracer + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.gemini.memory_tracer.memory_monitor + colossalai.gemini.memory_tracer.memstats_collector + colossalai.gemini.memory_tracer.model_data_memtracer diff --git a/docs/colossalai/colossalai.gemini.ophooks.rst b/docs/colossalai/colossalai.gemini.ophooks.rst new file mode 100644 index 0000000000000000000000000000000000000000..af87ab568ac02237f641c691a107665f4c2b5448 --- /dev/null +++ b/docs/colossalai/colossalai.gemini.ophooks.rst @@ -0,0 +1,11 @@ +colossalai.gemini.ophooks +========================= + +.. automodule:: colossalai.gemini.ophooks + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.gemini.ophooks.utils diff --git a/docs/colossalai/colossalai.gemini.ophooks.utils.rst b/docs/colossalai/colossalai.gemini.ophooks.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..5c5917047f44a504fddce31234f6c8190a222655 --- /dev/null +++ b/docs/colossalai/colossalai.gemini.ophooks.utils.rst @@ -0,0 +1,5 @@ +colossalai.gemini.ophooks.utils +=============================== + +.. automodule:: colossalai.gemini.ophooks.utils + :members: diff --git a/docs/colossalai/colossalai.gemini.paramhooks.rst b/docs/colossalai/colossalai.gemini.paramhooks.rst new file mode 100644 index 0000000000000000000000000000000000000000..28a823d4e69cf719820300e872508b3e5f7d521f --- /dev/null +++ b/docs/colossalai/colossalai.gemini.paramhooks.rst @@ -0,0 +1,5 @@ +colossalai.gemini.paramhooks +============================ + +.. automodule:: colossalai.gemini.paramhooks + :members: diff --git a/docs/colossalai/colossalai.gemini.placement_policy.rst b/docs/colossalai/colossalai.gemini.placement_policy.rst new file mode 100644 index 0000000000000000000000000000000000000000..9de0ed52371b22c61491049165cd59ca39c9c670 --- /dev/null +++ b/docs/colossalai/colossalai.gemini.placement_policy.rst @@ -0,0 +1,5 @@ +colossalai.gemini.placement\_policy +=================================== + +.. automodule:: colossalai.gemini.placement_policy + :members: diff --git a/docs/colossalai/colossalai.gemini.rst b/docs/colossalai/colossalai.gemini.rst new file mode 100644 index 0000000000000000000000000000000000000000..4f6efe386521f7e332451c3a1cf87eb7739c0ff5 --- /dev/null +++ b/docs/colossalai/colossalai.gemini.rst @@ -0,0 +1,27 @@ +colossalai.gemini +================= + +.. automodule:: colossalai.gemini + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.gemini.memory_tracer + colossalai.gemini.ophooks + colossalai.gemini.paramhooks + + +.. toctree:: + :maxdepth: 2 + + colossalai.gemini.chunk + colossalai.gemini.chunk_mgr + colossalai.gemini.gemini_context + colossalai.gemini.gemini_mgr + colossalai.gemini.placement_policy + colossalai.gemini.stateful_tensor + colossalai.gemini.stateful_tensor_container + colossalai.gemini.stateful_tensor_mgr + colossalai.gemini.tensor_placement_policy + colossalai.gemini.tensor_utils diff --git a/docs/colossalai/colossalai.gemini.stateful_tensor.rst b/docs/colossalai/colossalai.gemini.stateful_tensor.rst new file mode 100644 index 0000000000000000000000000000000000000000..02d526d1b4c8ad71ecf5dd81e4933338b9b2072d --- /dev/null +++ b/docs/colossalai/colossalai.gemini.stateful_tensor.rst @@ -0,0 +1,5 @@ +colossalai.gemini.stateful\_tensor +================================== + +.. automodule:: colossalai.gemini.stateful_tensor + :members: diff --git a/docs/colossalai/colossalai.gemini.stateful_tensor_container.rst b/docs/colossalai/colossalai.gemini.stateful_tensor_container.rst new file mode 100644 index 0000000000000000000000000000000000000000..be56c2aa8ed20d781238a8e07a1e1243b04fe315 --- /dev/null +++ b/docs/colossalai/colossalai.gemini.stateful_tensor_container.rst @@ -0,0 +1,5 @@ +colossalai.gemini.stateful\_tensor\_container +============================================= + +.. automodule:: colossalai.gemini.stateful_tensor_container + :members: diff --git a/docs/colossalai/colossalai.gemini.stateful_tensor_mgr.rst b/docs/colossalai/colossalai.gemini.stateful_tensor_mgr.rst new file mode 100644 index 0000000000000000000000000000000000000000..3456192bd735381ff00e7d2b35464b7e750a96f0 --- /dev/null +++ b/docs/colossalai/colossalai.gemini.stateful_tensor_mgr.rst @@ -0,0 +1,5 @@ +colossalai.gemini.stateful\_tensor\_mgr +======================================= + +.. automodule:: colossalai.gemini.stateful_tensor_mgr + :members: diff --git a/docs/colossalai/colossalai.gemini.tensor_placement_policy.rst b/docs/colossalai/colossalai.gemini.tensor_placement_policy.rst new file mode 100644 index 0000000000000000000000000000000000000000..81dcac33904873ff5dc3cc52824dca89733e3ac3 --- /dev/null +++ b/docs/colossalai/colossalai.gemini.tensor_placement_policy.rst @@ -0,0 +1,5 @@ +colossalai.gemini.tensor\_placement\_policy +=========================================== + +.. automodule:: colossalai.gemini.tensor_placement_policy + :members: diff --git a/docs/colossalai/colossalai.gemini.tensor_utils.rst b/docs/colossalai/colossalai.gemini.tensor_utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..385baf4b50bbcc1acf10a4fd980b31ddde2140bb --- /dev/null +++ b/docs/colossalai/colossalai.gemini.tensor_utils.rst @@ -0,0 +1,5 @@ +colossalai.gemini.tensor\_utils +=============================== + +.. automodule:: colossalai.gemini.tensor_utils + :members: diff --git a/docs/colossalai/colossalai.global_variables.rst b/docs/colossalai/colossalai.global_variables.rst new file mode 100644 index 0000000000000000000000000000000000000000..1900c88351ff2a1435c4a1d57c86dbe6ee2cbdba --- /dev/null +++ b/docs/colossalai/colossalai.global_variables.rst @@ -0,0 +1,5 @@ +colossalai.global\_variables +============================ + +.. automodule:: colossalai.global_variables + :members: diff --git a/docs/colossalai/colossalai.initialize.rst b/docs/colossalai/colossalai.initialize.rst new file mode 100644 index 0000000000000000000000000000000000000000..d3f65076a795876a34d3bcbcc03a4b4b96a28e79 --- /dev/null +++ b/docs/colossalai/colossalai.initialize.rst @@ -0,0 +1,5 @@ +colossalai.initialize +===================== + +.. automodule:: colossalai.initialize + :members: diff --git a/docs/colossalai/colossalai.kernel.cuda_native.layer_norm.rst b/docs/colossalai/colossalai.kernel.cuda_native.layer_norm.rst new file mode 100644 index 0000000000000000000000000000000000000000..b8bff51bef34d1cd5d515f1fb36da8d5634af13c --- /dev/null +++ b/docs/colossalai/colossalai.kernel.cuda_native.layer_norm.rst @@ -0,0 +1,5 @@ +colossalai.kernel.cuda\_native.layer\_norm +========================================== + +.. automodule:: colossalai.kernel.cuda_native.layer_norm + :members: diff --git a/docs/colossalai/colossalai.kernel.cuda_native.multihead_attention.rst b/docs/colossalai/colossalai.kernel.cuda_native.multihead_attention.rst new file mode 100644 index 0000000000000000000000000000000000000000..de7577d195cd70de7054af3d68d1b33275f90f5c --- /dev/null +++ b/docs/colossalai/colossalai.kernel.cuda_native.multihead_attention.rst @@ -0,0 +1,5 @@ +colossalai.kernel.cuda\_native.multihead\_attention +=================================================== + +.. automodule:: colossalai.kernel.cuda_native.multihead_attention + :members: diff --git a/docs/colossalai/colossalai.kernel.cuda_native.rst b/docs/colossalai/colossalai.kernel.cuda_native.rst new file mode 100644 index 0000000000000000000000000000000000000000..d88e4cfdb761f37266ce9ee2433ef440a8923105 --- /dev/null +++ b/docs/colossalai/colossalai.kernel.cuda_native.rst @@ -0,0 +1,13 @@ +colossalai.kernel.cuda\_native +============================== + +.. automodule:: colossalai.kernel.cuda_native + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.kernel.cuda_native.layer_norm + colossalai.kernel.cuda_native.multihead_attention + colossalai.kernel.cuda_native.scaled_softmax diff --git a/docs/colossalai/colossalai.kernel.cuda_native.scaled_softmax.rst b/docs/colossalai/colossalai.kernel.cuda_native.scaled_softmax.rst new file mode 100644 index 0000000000000000000000000000000000000000..474fcd3349bd79c9db707ef462cd31ac3ccde249 --- /dev/null +++ b/docs/colossalai/colossalai.kernel.cuda_native.scaled_softmax.rst @@ -0,0 +1,5 @@ +colossalai.kernel.cuda\_native.scaled\_softmax +============================================== + +.. automodule:: colossalai.kernel.cuda_native.scaled_softmax + :members: diff --git a/docs/colossalai/colossalai.kernel.jit.bias_dropout_add.rst b/docs/colossalai/colossalai.kernel.jit.bias_dropout_add.rst new file mode 100644 index 0000000000000000000000000000000000000000..d61550928bc8742060ce912cf55725d29a3168e4 --- /dev/null +++ b/docs/colossalai/colossalai.kernel.jit.bias_dropout_add.rst @@ -0,0 +1,5 @@ +colossalai.kernel.jit.bias\_dropout\_add +======================================== + +.. automodule:: colossalai.kernel.jit.bias_dropout_add + :members: diff --git a/docs/colossalai/colossalai.kernel.jit.bias_gelu.rst b/docs/colossalai/colossalai.kernel.jit.bias_gelu.rst new file mode 100644 index 0000000000000000000000000000000000000000..7db184b4ce3bd96d692debe06f86ef7d000318ed --- /dev/null +++ b/docs/colossalai/colossalai.kernel.jit.bias_gelu.rst @@ -0,0 +1,5 @@ +colossalai.kernel.jit.bias\_gelu +================================ + +.. automodule:: colossalai.kernel.jit.bias_gelu + :members: diff --git a/docs/colossalai/colossalai.kernel.jit.option.rst b/docs/colossalai/colossalai.kernel.jit.option.rst new file mode 100644 index 0000000000000000000000000000000000000000..15ebfc83aa7744444d2b60471264cb7d415b26dc --- /dev/null +++ b/docs/colossalai/colossalai.kernel.jit.option.rst @@ -0,0 +1,5 @@ +colossalai.kernel.jit.option +============================ + +.. automodule:: colossalai.kernel.jit.option + :members: diff --git a/docs/colossalai/colossalai.kernel.jit.rst b/docs/colossalai/colossalai.kernel.jit.rst new file mode 100644 index 0000000000000000000000000000000000000000..8b2f728d34d55d326979827a6cb3e23022ebb99e --- /dev/null +++ b/docs/colossalai/colossalai.kernel.jit.rst @@ -0,0 +1,13 @@ +colossalai.kernel.jit +===================== + +.. automodule:: colossalai.kernel.jit + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.kernel.jit.bias_dropout_add + colossalai.kernel.jit.bias_gelu + colossalai.kernel.jit.option diff --git a/docs/colossalai/colossalai.kernel.rst b/docs/colossalai/colossalai.kernel.rst new file mode 100644 index 0000000000000000000000000000000000000000..dcbac8c1de76167ee5dcaccc8dde8c3805759ae1 --- /dev/null +++ b/docs/colossalai/colossalai.kernel.rst @@ -0,0 +1,11 @@ +colossalai.kernel +================= + +.. automodule:: colossalai.kernel + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.kernel.cuda_native + colossalai.kernel.jit diff --git a/docs/colossalai/colossalai.logging.logger.rst b/docs/colossalai/colossalai.logging.logger.rst new file mode 100644 index 0000000000000000000000000000000000000000..047deb8a1d19704d853a490eaff90b32051148d7 --- /dev/null +++ b/docs/colossalai/colossalai.logging.logger.rst @@ -0,0 +1,5 @@ +colossalai.logging.logger +========================= + +.. automodule:: colossalai.logging.logger + :members: diff --git a/docs/colossalai/colossalai.logging.rst b/docs/colossalai/colossalai.logging.rst new file mode 100644 index 0000000000000000000000000000000000000000..bc593fc81bf47f5221c835363049b11f0a3a6ca9 --- /dev/null +++ b/docs/colossalai/colossalai.logging.rst @@ -0,0 +1,11 @@ +colossalai.logging +================== + +.. automodule:: colossalai.logging + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.logging.logger diff --git a/docs/colossalai/colossalai.nn.graph.graph_node.rst b/docs/colossalai/colossalai.nn.graph.graph_node.rst new file mode 100644 index 0000000000000000000000000000000000000000..335ecfe620feac9b0ecbb68edffa5fc70b0629a8 --- /dev/null +++ b/docs/colossalai/colossalai.nn.graph.graph_node.rst @@ -0,0 +1,5 @@ +colossalai.nn.graph.graph\_node +=============================== + +.. automodule:: colossalai.nn.graph.graph_node + :members: diff --git a/docs/colossalai/colossalai.nn.graph.rst b/docs/colossalai/colossalai.nn.graph.rst new file mode 100644 index 0000000000000000000000000000000000000000..4510b3374f2aafd45aab22fb93e81a1864150777 --- /dev/null +++ b/docs/colossalai/colossalai.nn.graph.rst @@ -0,0 +1,12 @@ +colossalai.nn.graph +=================== + +.. automodule:: colossalai.nn.graph + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.graph.graph_node + colossalai.nn.graph.utils diff --git a/docs/colossalai/colossalai.nn.graph.utils.rst b/docs/colossalai/colossalai.nn.graph.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..866a93cd92018da428ee3390c34d4f29a66c9bcc --- /dev/null +++ b/docs/colossalai/colossalai.nn.graph.utils.rst @@ -0,0 +1,5 @@ +colossalai.nn.graph.utils +========================= + +.. automodule:: colossalai.nn.graph.utils + :members: diff --git a/docs/colossalai/colossalai.nn.init.rst b/docs/colossalai/colossalai.nn.init.rst new file mode 100644 index 0000000000000000000000000000000000000000..d0ab993126d5b3b63b7d0aeab86031d716f1b301 --- /dev/null +++ b/docs/colossalai/colossalai.nn.init.rst @@ -0,0 +1,5 @@ +colossalai.nn.init +================== + +.. automodule:: colossalai.nn.init + :members: diff --git a/docs/colossalai/colossalai.nn.layer.base_layer.rst b/docs/colossalai/colossalai.nn.layer.base_layer.rst new file mode 100644 index 0000000000000000000000000000000000000000..c2a22f04d3f37c22b54aaacaf89b962e947d7d80 --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.base_layer.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.base\_layer +=============================== + +.. automodule:: colossalai.nn.layer.base_layer + :members: diff --git a/docs/colossalai/colossalai.nn.layer.colossalai_layer.dropout.rst b/docs/colossalai/colossalai.nn.layer.colossalai_layer.dropout.rst new file mode 100644 index 0000000000000000000000000000000000000000..ec1dfd395f1709ddb696b3809c39a19d7a5efd13 --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.colossalai_layer.dropout.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.colossalai\_layer.dropout +============================================= + +.. automodule:: colossalai.nn.layer.colossalai_layer.dropout + :members: diff --git a/docs/colossalai/colossalai.nn.layer.colossalai_layer.embedding.rst b/docs/colossalai/colossalai.nn.layer.colossalai_layer.embedding.rst new file mode 100644 index 0000000000000000000000000000000000000000..8438b3a077879e7722e3eb47414ef06adc3d30c3 --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.colossalai_layer.embedding.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.colossalai\_layer.embedding +=============================================== + +.. automodule:: colossalai.nn.layer.colossalai_layer.embedding + :members: diff --git a/docs/colossalai/colossalai.nn.layer.colossalai_layer.linear.rst b/docs/colossalai/colossalai.nn.layer.colossalai_layer.linear.rst new file mode 100644 index 0000000000000000000000000000000000000000..3213282549eaca72275d6a300813d937f3348526 --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.colossalai_layer.linear.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.colossalai\_layer.linear +============================================ + +.. automodule:: colossalai.nn.layer.colossalai_layer.linear + :members: diff --git a/docs/colossalai/colossalai.nn.layer.colossalai_layer.normalization.rst b/docs/colossalai/colossalai.nn.layer.colossalai_layer.normalization.rst new file mode 100644 index 0000000000000000000000000000000000000000..f94dd27b86e43f31719f54d90c033a2c3bd85e6e --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.colossalai_layer.normalization.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.colossalai\_layer.normalization +=================================================== + +.. automodule:: colossalai.nn.layer.colossalai_layer.normalization + :members: diff --git a/docs/colossalai/colossalai.nn.layer.colossalai_layer.rst b/docs/colossalai/colossalai.nn.layer.colossalai_layer.rst new file mode 100644 index 0000000000000000000000000000000000000000..0f685e6c2dc3a463f237d19764e9c8297360c5a2 --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.colossalai_layer.rst @@ -0,0 +1,14 @@ +colossalai.nn.layer.colossalai\_layer +===================================== + +.. automodule:: colossalai.nn.layer.colossalai_layer + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.layer.colossalai_layer.dropout + colossalai.nn.layer.colossalai_layer.embedding + colossalai.nn.layer.colossalai_layer.linear + colossalai.nn.layer.colossalai_layer.normalization diff --git a/docs/colossalai/colossalai.nn.layer.moe.experts.rst b/docs/colossalai/colossalai.nn.layer.moe.experts.rst new file mode 100644 index 0000000000000000000000000000000000000000..c05e763d572363d6d6c38c6772b4b7ce3a191173 --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.moe.experts.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.moe.experts +=============================== + +.. automodule:: colossalai.nn.layer.moe.experts + :members: diff --git a/docs/colossalai/colossalai.nn.layer.moe.layers.rst b/docs/colossalai/colossalai.nn.layer.moe.layers.rst new file mode 100644 index 0000000000000000000000000000000000000000..d109d47b8174375f561b717236bc9020c2e3675d --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.moe.layers.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.moe.layers +============================== + +.. automodule:: colossalai.nn.layer.moe.layers + :members: diff --git a/docs/colossalai/colossalai.nn.layer.moe.rst b/docs/colossalai/colossalai.nn.layer.moe.rst new file mode 100644 index 0000000000000000000000000000000000000000..f3106b98d4051a26c9d5b06aa9168d32dfe4a49b --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.moe.rst @@ -0,0 +1,13 @@ +colossalai.nn.layer.moe +======================= + +.. automodule:: colossalai.nn.layer.moe + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.layer.moe.experts + colossalai.nn.layer.moe.layers + colossalai.nn.layer.moe.utils diff --git a/docs/colossalai/colossalai.nn.layer.moe.utils.rst b/docs/colossalai/colossalai.nn.layer.moe.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..fc085d136bb4ee88c6b47105ada2d53c3f38c423 --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.moe.utils.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.moe.utils +============================= + +.. automodule:: colossalai.nn.layer.moe.utils + :members: diff --git a/docs/colossalai/colossalai.nn.layer.parallel_1d.layers.rst b/docs/colossalai/colossalai.nn.layer.parallel_1d.layers.rst new file mode 100644 index 0000000000000000000000000000000000000000..380f6bf8d134d55482902d9ea6c1b18f09955bb0 --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.parallel_1d.layers.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.parallel\_1d.layers +======================================= + +.. automodule:: colossalai.nn.layer.parallel_1d.layers + :members: diff --git a/docs/colossalai/colossalai.nn.layer.parallel_1d.rst b/docs/colossalai/colossalai.nn.layer.parallel_1d.rst new file mode 100644 index 0000000000000000000000000000000000000000..3a8ed620672189e6178925d46a21753b7d4f79e3 --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.parallel_1d.rst @@ -0,0 +1,11 @@ +colossalai.nn.layer.parallel\_1d +================================ + +.. automodule:: colossalai.nn.layer.parallel_1d + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.layer.parallel_1d.layers diff --git a/docs/colossalai/colossalai.nn.layer.parallel_2d.layers.rst b/docs/colossalai/colossalai.nn.layer.parallel_2d.layers.rst new file mode 100644 index 0000000000000000000000000000000000000000..b64d402bdf3e608fc522f0567319a37b4bd35b2a --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.parallel_2d.layers.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.parallel\_2d.layers +======================================= + +.. automodule:: colossalai.nn.layer.parallel_2d.layers + :members: diff --git a/docs/colossalai/colossalai.nn.layer.parallel_2d.rst b/docs/colossalai/colossalai.nn.layer.parallel_2d.rst new file mode 100644 index 0000000000000000000000000000000000000000..f5ad41a1b450ae721809fe912a762006bd77e8ad --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.parallel_2d.rst @@ -0,0 +1,11 @@ +colossalai.nn.layer.parallel\_2d +================================ + +.. automodule:: colossalai.nn.layer.parallel_2d + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.layer.parallel_2d.layers diff --git a/docs/colossalai/colossalai.nn.layer.parallel_2p5d.layers.rst b/docs/colossalai/colossalai.nn.layer.parallel_2p5d.layers.rst new file mode 100644 index 0000000000000000000000000000000000000000..ebc99d56ccdc58675d99e6d25c6e5867f9144e4f --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.parallel_2p5d.layers.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.parallel\_2p5d.layers +========================================= + +.. automodule:: colossalai.nn.layer.parallel_2p5d.layers + :members: diff --git a/docs/colossalai/colossalai.nn.layer.parallel_2p5d.rst b/docs/colossalai/colossalai.nn.layer.parallel_2p5d.rst new file mode 100644 index 0000000000000000000000000000000000000000..5869bdee9928d1843b171011fb7f9e86169d6fe8 --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.parallel_2p5d.rst @@ -0,0 +1,11 @@ +colossalai.nn.layer.parallel\_2p5d +================================== + +.. automodule:: colossalai.nn.layer.parallel_2p5d + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.layer.parallel_2p5d.layers diff --git a/docs/colossalai/colossalai.nn.layer.parallel_3d.layers.rst b/docs/colossalai/colossalai.nn.layer.parallel_3d.layers.rst new file mode 100644 index 0000000000000000000000000000000000000000..a1702f1fcf627cd5d49996af5d65dd1388007e7a --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.parallel_3d.layers.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.parallel\_3d.layers +======================================= + +.. automodule:: colossalai.nn.layer.parallel_3d.layers + :members: diff --git a/docs/colossalai/colossalai.nn.layer.parallel_3d.rst b/docs/colossalai/colossalai.nn.layer.parallel_3d.rst new file mode 100644 index 0000000000000000000000000000000000000000..bb55a63e507d60e010c34c2edcc7a90fd21b0dc1 --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.parallel_3d.rst @@ -0,0 +1,11 @@ +colossalai.nn.layer.parallel\_3d +================================ + +.. automodule:: colossalai.nn.layer.parallel_3d + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.layer.parallel_3d.layers diff --git a/docs/colossalai/colossalai.nn.layer.parallel_sequence.layers.rst b/docs/colossalai/colossalai.nn.layer.parallel_sequence.layers.rst new file mode 100644 index 0000000000000000000000000000000000000000..54929d2e71690bda6a390e8ff64ad61f4f077e25 --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.parallel_sequence.layers.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.parallel\_sequence.layers +============================================= + +.. automodule:: colossalai.nn.layer.parallel_sequence.layers + :members: diff --git a/docs/colossalai/colossalai.nn.layer.parallel_sequence.rst b/docs/colossalai/colossalai.nn.layer.parallel_sequence.rst new file mode 100644 index 0000000000000000000000000000000000000000..24e8941d4ec4e6c64fb30c377055dbe8176283b2 --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.parallel_sequence.rst @@ -0,0 +1,11 @@ +colossalai.nn.layer.parallel\_sequence +====================================== + +.. automodule:: colossalai.nn.layer.parallel_sequence + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.layer.parallel_sequence.layers diff --git a/docs/colossalai/colossalai.nn.layer.rst b/docs/colossalai/colossalai.nn.layer.rst new file mode 100644 index 0000000000000000000000000000000000000000..32a93128f2a40431768048f2394f8427955eca7d --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.rst @@ -0,0 +1,25 @@ +colossalai.nn.layer +=================== + +.. automodule:: colossalai.nn.layer + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.layer.colossalai_layer + colossalai.nn.layer.moe + colossalai.nn.layer.parallel_1d + colossalai.nn.layer.parallel_2d + colossalai.nn.layer.parallel_2p5d + colossalai.nn.layer.parallel_3d + colossalai.nn.layer.parallel_sequence + colossalai.nn.layer.utils + colossalai.nn.layer.vanilla + colossalai.nn.layer.wrapper + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.layer.base_layer diff --git a/docs/colossalai/colossalai.nn.layer.utils.common.rst b/docs/colossalai/colossalai.nn.layer.utils.common.rst new file mode 100644 index 0000000000000000000000000000000000000000..6a552830f8f56652fd7f721087dfcefc278cffef --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.utils.common.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.utils.common +================================ + +.. automodule:: colossalai.nn.layer.utils.common + :members: diff --git a/docs/colossalai/colossalai.nn.layer.utils.rst b/docs/colossalai/colossalai.nn.layer.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..16c3d718286a10211684de54b36a474defef771d --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.utils.rst @@ -0,0 +1,11 @@ +colossalai.nn.layer.utils +========================= + +.. automodule:: colossalai.nn.layer.utils + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.layer.utils.common diff --git a/docs/colossalai/colossalai.nn.layer.vanilla.layers.rst b/docs/colossalai/colossalai.nn.layer.vanilla.layers.rst new file mode 100644 index 0000000000000000000000000000000000000000..f993b1f50e5bb5a57eb4699175aab3964cc68647 --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.vanilla.layers.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.vanilla.layers +================================== + +.. automodule:: colossalai.nn.layer.vanilla.layers + :members: diff --git a/docs/colossalai/colossalai.nn.layer.vanilla.rst b/docs/colossalai/colossalai.nn.layer.vanilla.rst new file mode 100644 index 0000000000000000000000000000000000000000..fe1ea5c6c53e4ead032a0e9707059454804dfd7d --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.vanilla.rst @@ -0,0 +1,11 @@ +colossalai.nn.layer.vanilla +=========================== + +.. automodule:: colossalai.nn.layer.vanilla + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.layer.vanilla.layers diff --git a/docs/colossalai/colossalai.nn.layer.wrapper.pipeline_wrapper.rst b/docs/colossalai/colossalai.nn.layer.wrapper.pipeline_wrapper.rst new file mode 100644 index 0000000000000000000000000000000000000000..e5648873d34b9c618700047ba7f3d08fd007dcfe --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.wrapper.pipeline_wrapper.rst @@ -0,0 +1,5 @@ +colossalai.nn.layer.wrapper.pipeline\_wrapper +============================================= + +.. automodule:: colossalai.nn.layer.wrapper.pipeline_wrapper + :members: diff --git a/docs/colossalai/colossalai.nn.layer.wrapper.rst b/docs/colossalai/colossalai.nn.layer.wrapper.rst new file mode 100644 index 0000000000000000000000000000000000000000..761bf843af365c05b296a691c87017916bdb7e26 --- /dev/null +++ b/docs/colossalai/colossalai.nn.layer.wrapper.rst @@ -0,0 +1,11 @@ +colossalai.nn.layer.wrapper +=========================== + +.. automodule:: colossalai.nn.layer.wrapper + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.layer.wrapper.pipeline_wrapper diff --git a/docs/colossalai/colossalai.nn.loss.loss_1d.rst b/docs/colossalai/colossalai.nn.loss.loss_1d.rst new file mode 100644 index 0000000000000000000000000000000000000000..d9ac2e67d317df6b5b62ed50cb5f22d09cd1c381 --- /dev/null +++ b/docs/colossalai/colossalai.nn.loss.loss_1d.rst @@ -0,0 +1,5 @@ +colossalai.nn.loss.loss\_1d +=========================== + +.. automodule:: colossalai.nn.loss.loss_1d + :members: diff --git a/docs/colossalai/colossalai.nn.loss.loss_2d.rst b/docs/colossalai/colossalai.nn.loss.loss_2d.rst new file mode 100644 index 0000000000000000000000000000000000000000..14d1585e3e0fe42943b40b75280fab2bf5993300 --- /dev/null +++ b/docs/colossalai/colossalai.nn.loss.loss_2d.rst @@ -0,0 +1,5 @@ +colossalai.nn.loss.loss\_2d +=========================== + +.. automodule:: colossalai.nn.loss.loss_2d + :members: diff --git a/docs/colossalai/colossalai.nn.loss.loss_2p5d.rst b/docs/colossalai/colossalai.nn.loss.loss_2p5d.rst new file mode 100644 index 0000000000000000000000000000000000000000..fc3714da36301a65e88bf1856cf74375580cce19 --- /dev/null +++ b/docs/colossalai/colossalai.nn.loss.loss_2p5d.rst @@ -0,0 +1,5 @@ +colossalai.nn.loss.loss\_2p5d +============================= + +.. automodule:: colossalai.nn.loss.loss_2p5d + :members: diff --git a/docs/colossalai/colossalai.nn.loss.loss_3d.rst b/docs/colossalai/colossalai.nn.loss.loss_3d.rst new file mode 100644 index 0000000000000000000000000000000000000000..a593324fb4f16741383477f3b44233e54c354859 --- /dev/null +++ b/docs/colossalai/colossalai.nn.loss.loss_3d.rst @@ -0,0 +1,5 @@ +colossalai.nn.loss.loss\_3d +=========================== + +.. automodule:: colossalai.nn.loss.loss_3d + :members: diff --git a/docs/colossalai/colossalai.nn.loss.loss_moe.rst b/docs/colossalai/colossalai.nn.loss.loss_moe.rst new file mode 100644 index 0000000000000000000000000000000000000000..ef2851ace83a0fb1c98464c761f0bf6d1063234f --- /dev/null +++ b/docs/colossalai/colossalai.nn.loss.loss_moe.rst @@ -0,0 +1,5 @@ +colossalai.nn.loss.loss\_moe +============================ + +.. automodule:: colossalai.nn.loss.loss_moe + :members: diff --git a/docs/colossalai/colossalai.nn.loss.rst b/docs/colossalai/colossalai.nn.loss.rst new file mode 100644 index 0000000000000000000000000000000000000000..5df7d1ae37701678355848d4877841544c9109cb --- /dev/null +++ b/docs/colossalai/colossalai.nn.loss.rst @@ -0,0 +1,15 @@ +colossalai.nn.loss +================== + +.. automodule:: colossalai.nn.loss + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.loss.loss_1d + colossalai.nn.loss.loss_2d + colossalai.nn.loss.loss_2p5d + colossalai.nn.loss.loss_3d + colossalai.nn.loss.loss_moe diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.cosine.rst b/docs/colossalai/colossalai.nn.lr_scheduler.cosine.rst new file mode 100644 index 0000000000000000000000000000000000000000..a7c636ad3a364ed85e105df4102960d081edb434 --- /dev/null +++ b/docs/colossalai/colossalai.nn.lr_scheduler.cosine.rst @@ -0,0 +1,5 @@ +colossalai.nn.lr\_scheduler.cosine +================================== + +.. automodule:: colossalai.nn.lr_scheduler.cosine + :members: diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.delayed.rst b/docs/colossalai/colossalai.nn.lr_scheduler.delayed.rst new file mode 100644 index 0000000000000000000000000000000000000000..2a86c4b2a20c4e7f4db1c719d45db68ca475eea5 --- /dev/null +++ b/docs/colossalai/colossalai.nn.lr_scheduler.delayed.rst @@ -0,0 +1,5 @@ +colossalai.nn.lr\_scheduler.delayed +=================================== + +.. automodule:: colossalai.nn.lr_scheduler.delayed + :members: diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.linear.rst b/docs/colossalai/colossalai.nn.lr_scheduler.linear.rst new file mode 100644 index 0000000000000000000000000000000000000000..5e917edc2faf84b49f2025bc2aee4cae8b5fd422 --- /dev/null +++ b/docs/colossalai/colossalai.nn.lr_scheduler.linear.rst @@ -0,0 +1,5 @@ +colossalai.nn.lr\_scheduler.linear +================================== + +.. automodule:: colossalai.nn.lr_scheduler.linear + :members: diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.multistep.rst b/docs/colossalai/colossalai.nn.lr_scheduler.multistep.rst new file mode 100644 index 0000000000000000000000000000000000000000..4248a638637543a0196616fa2addc17cc79a2f6d --- /dev/null +++ b/docs/colossalai/colossalai.nn.lr_scheduler.multistep.rst @@ -0,0 +1,5 @@ +colossalai.nn.lr\_scheduler.multistep +===================================== + +.. automodule:: colossalai.nn.lr_scheduler.multistep + :members: diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.onecycle.rst b/docs/colossalai/colossalai.nn.lr_scheduler.onecycle.rst new file mode 100644 index 0000000000000000000000000000000000000000..7f2fd47586fea3ef2114b23654c56f64827b49bd --- /dev/null +++ b/docs/colossalai/colossalai.nn.lr_scheduler.onecycle.rst @@ -0,0 +1,5 @@ +colossalai.nn.lr\_scheduler.onecycle +==================================== + +.. automodule:: colossalai.nn.lr_scheduler.onecycle + :members: diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.poly.rst b/docs/colossalai/colossalai.nn.lr_scheduler.poly.rst new file mode 100644 index 0000000000000000000000000000000000000000..c1618812aa0c34b31deb3276a1c671ba142c2b80 --- /dev/null +++ b/docs/colossalai/colossalai.nn.lr_scheduler.poly.rst @@ -0,0 +1,5 @@ +colossalai.nn.lr\_scheduler.poly +================================ + +.. automodule:: colossalai.nn.lr_scheduler.poly + :members: diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.rst b/docs/colossalai/colossalai.nn.lr_scheduler.rst new file mode 100644 index 0000000000000000000000000000000000000000..427a3ee4529e45ec0adff7707a8f90e6650ec5ce --- /dev/null +++ b/docs/colossalai/colossalai.nn.lr_scheduler.rst @@ -0,0 +1,17 @@ +colossalai.nn.lr\_scheduler +=========================== + +.. automodule:: colossalai.nn.lr_scheduler + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.lr_scheduler.cosine + colossalai.nn.lr_scheduler.delayed + colossalai.nn.lr_scheduler.linear + colossalai.nn.lr_scheduler.multistep + colossalai.nn.lr_scheduler.onecycle + colossalai.nn.lr_scheduler.poly + colossalai.nn.lr_scheduler.torch diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.torch.rst b/docs/colossalai/colossalai.nn.lr_scheduler.torch.rst new file mode 100644 index 0000000000000000000000000000000000000000..f8d552bf1d62d069923ae713f159d0b5eeefd10a --- /dev/null +++ b/docs/colossalai/colossalai.nn.lr_scheduler.torch.rst @@ -0,0 +1,5 @@ +colossalai.nn.lr\_scheduler.torch +================================= + +.. automodule:: colossalai.nn.lr_scheduler.torch + :members: diff --git a/docs/colossalai/colossalai.nn.metric.accuracy_2d.rst b/docs/colossalai/colossalai.nn.metric.accuracy_2d.rst new file mode 100644 index 0000000000000000000000000000000000000000..63bcb834976384874049930eb21da742a5d1835b --- /dev/null +++ b/docs/colossalai/colossalai.nn.metric.accuracy_2d.rst @@ -0,0 +1,5 @@ +colossalai.nn.metric.accuracy\_2d +================================= + +.. automodule:: colossalai.nn.metric.accuracy_2d + :members: diff --git a/docs/colossalai/colossalai.nn.metric.accuracy_2p5d.rst b/docs/colossalai/colossalai.nn.metric.accuracy_2p5d.rst new file mode 100644 index 0000000000000000000000000000000000000000..dd4358fbff72eb5df642168cc4674a85de041387 --- /dev/null +++ b/docs/colossalai/colossalai.nn.metric.accuracy_2p5d.rst @@ -0,0 +1,5 @@ +colossalai.nn.metric.accuracy\_2p5d +=================================== + +.. automodule:: colossalai.nn.metric.accuracy_2p5d + :members: diff --git a/docs/colossalai/colossalai.nn.metric.accuracy_3d.rst b/docs/colossalai/colossalai.nn.metric.accuracy_3d.rst new file mode 100644 index 0000000000000000000000000000000000000000..95143444b945e35c46fa94528e4e7bddf27a19dd --- /dev/null +++ b/docs/colossalai/colossalai.nn.metric.accuracy_3d.rst @@ -0,0 +1,5 @@ +colossalai.nn.metric.accuracy\_3d +================================= + +.. automodule:: colossalai.nn.metric.accuracy_3d + :members: diff --git a/docs/colossalai/colossalai.nn.metric.rst b/docs/colossalai/colossalai.nn.metric.rst new file mode 100644 index 0000000000000000000000000000000000000000..28f5568eb84696011e46fdfaad59aed00964b094 --- /dev/null +++ b/docs/colossalai/colossalai.nn.metric.rst @@ -0,0 +1,13 @@ +colossalai.nn.metric +==================== + +.. automodule:: colossalai.nn.metric + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.metric.accuracy_2d + colossalai.nn.metric.accuracy_2p5d + colossalai.nn.metric.accuracy_3d diff --git a/docs/colossalai/colossalai.nn.optimizer.colossalai_optimizer.rst b/docs/colossalai/colossalai.nn.optimizer.colossalai_optimizer.rst new file mode 100644 index 0000000000000000000000000000000000000000..35515c374f3360452f74ed95a4750f255ef4ba56 --- /dev/null +++ b/docs/colossalai/colossalai.nn.optimizer.colossalai_optimizer.rst @@ -0,0 +1,5 @@ +colossalai.nn.optimizer.colossalai\_optimizer +============================================= + +.. automodule:: colossalai.nn.optimizer.colossalai_optimizer + :members: diff --git a/docs/colossalai/colossalai.nn.optimizer.cpu_adam.rst b/docs/colossalai/colossalai.nn.optimizer.cpu_adam.rst new file mode 100644 index 0000000000000000000000000000000000000000..224dfab43ed0dfb0b4b17d93ca09eacdf863e818 --- /dev/null +++ b/docs/colossalai/colossalai.nn.optimizer.cpu_adam.rst @@ -0,0 +1,5 @@ +colossalai.nn.optimizer.cpu\_adam +================================= + +.. automodule:: colossalai.nn.optimizer.cpu_adam + :members: diff --git a/docs/colossalai/colossalai.nn.optimizer.fused_adam.rst b/docs/colossalai/colossalai.nn.optimizer.fused_adam.rst new file mode 100644 index 0000000000000000000000000000000000000000..60af624cb6c12f78d488d651cc612c82bf55ad8c --- /dev/null +++ b/docs/colossalai/colossalai.nn.optimizer.fused_adam.rst @@ -0,0 +1,5 @@ +colossalai.nn.optimizer.fused\_adam +=================================== + +.. automodule:: colossalai.nn.optimizer.fused_adam + :members: diff --git a/docs/colossalai/colossalai.nn.optimizer.fused_lamb.rst b/docs/colossalai/colossalai.nn.optimizer.fused_lamb.rst new file mode 100644 index 0000000000000000000000000000000000000000..66c0fa4ca1c7c8880ceb48b564a621c55be4687d --- /dev/null +++ b/docs/colossalai/colossalai.nn.optimizer.fused_lamb.rst @@ -0,0 +1,5 @@ +colossalai.nn.optimizer.fused\_lamb +=================================== + +.. automodule:: colossalai.nn.optimizer.fused_lamb + :members: diff --git a/docs/colossalai/colossalai.nn.optimizer.fused_sgd.rst b/docs/colossalai/colossalai.nn.optimizer.fused_sgd.rst new file mode 100644 index 0000000000000000000000000000000000000000..2ecc77c33d88cf11e8075bc04dcc17026eeadc75 --- /dev/null +++ b/docs/colossalai/colossalai.nn.optimizer.fused_sgd.rst @@ -0,0 +1,5 @@ +colossalai.nn.optimizer.fused\_sgd +================================== + +.. automodule:: colossalai.nn.optimizer.fused_sgd + :members: diff --git a/docs/colossalai/colossalai.nn.optimizer.hybrid_adam.rst b/docs/colossalai/colossalai.nn.optimizer.hybrid_adam.rst new file mode 100644 index 0000000000000000000000000000000000000000..20508d6647017aa4723b6d404e22c64a97f97f3b --- /dev/null +++ b/docs/colossalai/colossalai.nn.optimizer.hybrid_adam.rst @@ -0,0 +1,5 @@ +colossalai.nn.optimizer.hybrid\_adam +==================================== + +.. automodule:: colossalai.nn.optimizer.hybrid_adam + :members: diff --git a/docs/colossalai/colossalai.nn.optimizer.lamb.rst b/docs/colossalai/colossalai.nn.optimizer.lamb.rst new file mode 100644 index 0000000000000000000000000000000000000000..57199ea3695132e4a6e76b5cf41da81ce7a37bd8 --- /dev/null +++ b/docs/colossalai/colossalai.nn.optimizer.lamb.rst @@ -0,0 +1,5 @@ +colossalai.nn.optimizer.lamb +============================ + +.. automodule:: colossalai.nn.optimizer.lamb + :members: diff --git a/docs/colossalai/colossalai.nn.optimizer.lars.rst b/docs/colossalai/colossalai.nn.optimizer.lars.rst new file mode 100644 index 0000000000000000000000000000000000000000..f935950f8b5a2a1a8a5a757d7454a875d4db69c6 --- /dev/null +++ b/docs/colossalai/colossalai.nn.optimizer.lars.rst @@ -0,0 +1,5 @@ +colossalai.nn.optimizer.lars +============================ + +.. automodule:: colossalai.nn.optimizer.lars + :members: diff --git a/docs/colossalai/colossalai.nn.optimizer.rst b/docs/colossalai/colossalai.nn.optimizer.rst new file mode 100644 index 0000000000000000000000000000000000000000..ede9cc496967cde563c6a90189c4715c063e0cb3 --- /dev/null +++ b/docs/colossalai/colossalai.nn.optimizer.rst @@ -0,0 +1,19 @@ +colossalai.nn.optimizer +======================= + +.. automodule:: colossalai.nn.optimizer + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.optimizer.colossalai_optimizer + colossalai.nn.optimizer.cpu_adam + colossalai.nn.optimizer.fused_adam + colossalai.nn.optimizer.fused_lamb + colossalai.nn.optimizer.fused_sgd + colossalai.nn.optimizer.hybrid_adam + colossalai.nn.optimizer.lamb + colossalai.nn.optimizer.lars + colossalai.nn.optimizer.utils diff --git a/docs/colossalai/colossalai.nn.optimizer.utils.rst b/docs/colossalai/colossalai.nn.optimizer.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..9b2bc2f016c41196816eee1b21bb3d03858d166c --- /dev/null +++ b/docs/colossalai/colossalai.nn.optimizer.utils.rst @@ -0,0 +1,5 @@ +colossalai.nn.optimizer.utils +============================= + +.. automodule:: colossalai.nn.optimizer.utils + :members: diff --git a/docs/colossalai/colossalai.nn.parallel.data_parallel.rst b/docs/colossalai/colossalai.nn.parallel.data_parallel.rst new file mode 100644 index 0000000000000000000000000000000000000000..ba987c2ee2f35b576bfd1e6b54e8f8f7d6eecedc --- /dev/null +++ b/docs/colossalai/colossalai.nn.parallel.data_parallel.rst @@ -0,0 +1,5 @@ +colossalai.nn.parallel.data\_parallel +===================================== + +.. automodule:: colossalai.nn.parallel.data_parallel + :members: diff --git a/docs/colossalai/colossalai.nn.parallel.layers.colo_module.rst b/docs/colossalai/colossalai.nn.parallel.layers.colo_module.rst new file mode 100644 index 0000000000000000000000000000000000000000..c80fff6d543af9b8d77711ff9f8a1a63021c730f --- /dev/null +++ b/docs/colossalai/colossalai.nn.parallel.layers.colo_module.rst @@ -0,0 +1,5 @@ +colossalai.nn.parallel.layers.colo\_module +========================================== + +.. automodule:: colossalai.nn.parallel.layers.colo_module + :members: diff --git a/docs/colossalai/colossalai.nn.parallel.layers.embedding.rst b/docs/colossalai/colossalai.nn.parallel.layers.embedding.rst new file mode 100644 index 0000000000000000000000000000000000000000..1e7ecc50f4786d1355fa5b3e7303fab03e514fa8 --- /dev/null +++ b/docs/colossalai/colossalai.nn.parallel.layers.embedding.rst @@ -0,0 +1,5 @@ +colossalai.nn.parallel.layers.embedding +======================================= + +.. automodule:: colossalai.nn.parallel.layers.embedding + :members: diff --git a/docs/colossalai/colossalai.nn.parallel.layers.linear.rst b/docs/colossalai/colossalai.nn.parallel.layers.linear.rst new file mode 100644 index 0000000000000000000000000000000000000000..bbc5e32570e701061f0b93c1f32f314718abd07f --- /dev/null +++ b/docs/colossalai/colossalai.nn.parallel.layers.linear.rst @@ -0,0 +1,5 @@ +colossalai.nn.parallel.layers.linear +==================================== + +.. automodule:: colossalai.nn.parallel.layers.linear + :members: diff --git a/docs/colossalai/colossalai.nn.parallel.layers.module_utils.rst b/docs/colossalai/colossalai.nn.parallel.layers.module_utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..5190ab40345a07c20cec3bf501c5000d394ad365 --- /dev/null +++ b/docs/colossalai/colossalai.nn.parallel.layers.module_utils.rst @@ -0,0 +1,5 @@ +colossalai.nn.parallel.layers.module\_utils +=========================================== + +.. automodule:: colossalai.nn.parallel.layers.module_utils + :members: diff --git a/docs/colossalai/colossalai.nn.parallel.layers.rst b/docs/colossalai/colossalai.nn.parallel.layers.rst new file mode 100644 index 0000000000000000000000000000000000000000..782a206e88d5da274175cfe324ed5ed6714603d2 --- /dev/null +++ b/docs/colossalai/colossalai.nn.parallel.layers.rst @@ -0,0 +1,14 @@ +colossalai.nn.parallel.layers +============================= + +.. automodule:: colossalai.nn.parallel.layers + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.parallel.layers.colo_module + colossalai.nn.parallel.layers.embedding + colossalai.nn.parallel.layers.linear + colossalai.nn.parallel.layers.module_utils diff --git a/docs/colossalai/colossalai.nn.parallel.reducer.rst b/docs/colossalai/colossalai.nn.parallel.reducer.rst new file mode 100644 index 0000000000000000000000000000000000000000..d80841f6916e0ca37f32bac87e4558deb0b77ec0 --- /dev/null +++ b/docs/colossalai/colossalai.nn.parallel.reducer.rst @@ -0,0 +1,5 @@ +colossalai.nn.parallel.reducer +============================== + +.. automodule:: colossalai.nn.parallel.reducer + :members: diff --git a/docs/colossalai/colossalai.nn.parallel.rst b/docs/colossalai/colossalai.nn.parallel.rst new file mode 100644 index 0000000000000000000000000000000000000000..19e9d1eef19bf8eaf472ec322ce6334794a44db2 --- /dev/null +++ b/docs/colossalai/colossalai.nn.parallel.rst @@ -0,0 +1,17 @@ +colossalai.nn.parallel +====================== + +.. automodule:: colossalai.nn.parallel + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.parallel.layers + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.parallel.data_parallel + colossalai.nn.parallel.reducer diff --git a/docs/colossalai/colossalai.nn.rst b/docs/colossalai/colossalai.nn.rst new file mode 100644 index 0000000000000000000000000000000000000000..7e683952f3dbb0c98944e6f44c963bacf3424bd7 --- /dev/null +++ b/docs/colossalai/colossalai.nn.rst @@ -0,0 +1,22 @@ +colossalai.nn +============= + +.. automodule:: colossalai.nn + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.graph + colossalai.nn.layer + colossalai.nn.loss + colossalai.nn.lr_scheduler + colossalai.nn.metric + colossalai.nn.optimizer + colossalai.nn.parallel + + +.. toctree:: + :maxdepth: 2 + + colossalai.nn.init diff --git a/docs/colossalai/colossalai.pipeline.layer_sepc.rst b/docs/colossalai/colossalai.pipeline.layer_sepc.rst new file mode 100644 index 0000000000000000000000000000000000000000..156660b5c00fc0aaf684572df332e1093487a049 --- /dev/null +++ b/docs/colossalai/colossalai.pipeline.layer_sepc.rst @@ -0,0 +1,5 @@ +colossalai.pipeline.layer\_sepc +=============================== + +.. automodule:: colossalai.pipeline.layer_spec + :members: diff --git a/docs/colossalai/colossalai.pipeline.pipelinable.rst b/docs/colossalai/colossalai.pipeline.pipelinable.rst new file mode 100644 index 0000000000000000000000000000000000000000..5c2b02ba63e2a8f9176b115b09a7b1c9375c2d0f --- /dev/null +++ b/docs/colossalai/colossalai.pipeline.pipelinable.rst @@ -0,0 +1,5 @@ +colossalai.pipeline.pipelinable +=============================== + +.. automodule:: colossalai.pipeline.pipelinable + :members: diff --git a/docs/colossalai/colossalai.pipeline.rst b/docs/colossalai/colossalai.pipeline.rst new file mode 100644 index 0000000000000000000000000000000000000000..6f7652d492e074ba0ca71be52412e264e9a5a031 --- /dev/null +++ b/docs/colossalai/colossalai.pipeline.rst @@ -0,0 +1,13 @@ +colossalai.pipeline +=================== + +.. automodule:: colossalai.pipeline + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.pipeline.layer_spec + colossalai.pipeline.pipelinable + colossalai.pipeline.utils diff --git a/docs/colossalai/colossalai.pipeline.utils.rst b/docs/colossalai/colossalai.pipeline.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..a33bf42cfc2b52b7685b788ef7c7a7bbf3c0982d --- /dev/null +++ b/docs/colossalai/colossalai.pipeline.utils.rst @@ -0,0 +1,5 @@ +colossalai.pipeline.utils +========================= + +.. automodule:: colossalai.pipeline.utils + :members: diff --git a/docs/colossalai/colossalai.registry.registry.rst b/docs/colossalai/colossalai.registry.registry.rst new file mode 100644 index 0000000000000000000000000000000000000000..e942d7969b60beb309f38e3ea1b5e82614941a4c --- /dev/null +++ b/docs/colossalai/colossalai.registry.registry.rst @@ -0,0 +1,5 @@ +colossalai.registry.registry +============================ + +.. automodule:: colossalai.registry.registry + :members: diff --git a/docs/colossalai/colossalai.registry.rst b/docs/colossalai/colossalai.registry.rst new file mode 100644 index 0000000000000000000000000000000000000000..0f294f6d15a7285709b69e6c3cddaa2cc2e47833 --- /dev/null +++ b/docs/colossalai/colossalai.registry.rst @@ -0,0 +1,11 @@ +colossalai.registry +=================== + +.. automodule:: colossalai.registry + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.registry.registry diff --git a/docs/colossalai/colossalai.rst b/docs/colossalai/colossalai.rst new file mode 100644 index 0000000000000000000000000000000000000000..921f15a97f0021a7ac80a48ec4b94fe4a3d27e91 --- /dev/null +++ b/docs/colossalai/colossalai.rst @@ -0,0 +1,36 @@ +colossalai +========== + +.. automodule:: colossalai + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.amp + colossalai.builder + colossalai.cli + colossalai.communication + colossalai.context + colossalai.engine + colossalai.fx + colossalai.gemini + colossalai.kernel + colossalai.logging + colossalai.nn + colossalai.pipeline + colossalai.registry + colossalai.tensor + colossalai.testing + colossalai.trainer + colossalai.utils + colossalai.zero + + +.. toctree:: + :maxdepth: 2 + + colossalai.constants + colossalai.core + colossalai.global_variables + colossalai.initialize diff --git a/docs/colossalai/colossalai.tensor.colo_parameter.rst b/docs/colossalai/colossalai.tensor.colo_parameter.rst new file mode 100644 index 0000000000000000000000000000000000000000..9b65029dbbe4e7e6c38d29eb0d4fadcbc78cc055 --- /dev/null +++ b/docs/colossalai/colossalai.tensor.colo_parameter.rst @@ -0,0 +1,5 @@ +colossalai.tensor.colo\_parameter +================================= + +.. automodule:: colossalai.tensor.colo_parameter + :members: diff --git a/docs/colossalai/colossalai.tensor.colo_tensor.rst b/docs/colossalai/colossalai.tensor.colo_tensor.rst new file mode 100644 index 0000000000000000000000000000000000000000..9161ac22f665994ff3185f032292b44c696ce32b --- /dev/null +++ b/docs/colossalai/colossalai.tensor.colo_tensor.rst @@ -0,0 +1,5 @@ +colossalai.tensor.colo\_tensor +============================== + +.. automodule:: colossalai.tensor.colo_tensor + :members: diff --git a/docs/colossalai/colossalai.tensor.compute_spec.rst b/docs/colossalai/colossalai.tensor.compute_spec.rst new file mode 100644 index 0000000000000000000000000000000000000000..e2d7235d99c4d7e5fe3e2ba60e74858e0c7b6b5c --- /dev/null +++ b/docs/colossalai/colossalai.tensor.compute_spec.rst @@ -0,0 +1,5 @@ +colossalai.tensor.compute\_spec +=============================== + +.. automodule:: colossalai.tensor.compute_spec + :members: diff --git a/docs/colossalai/colossalai.tensor.const.rst b/docs/colossalai/colossalai.tensor.const.rst new file mode 100644 index 0000000000000000000000000000000000000000..a22a2789349bb0aebca4ca8bd35f019ecb7beb2c --- /dev/null +++ b/docs/colossalai/colossalai.tensor.const.rst @@ -0,0 +1,5 @@ +colossalai.tensor.const +======================= + +.. automodule:: colossalai.tensor.const + :members: diff --git a/docs/colossalai/colossalai.tensor.dist_spec_mgr.rst b/docs/colossalai/colossalai.tensor.dist_spec_mgr.rst new file mode 100644 index 0000000000000000000000000000000000000000..043cf22604a3df351439ccd2c8a875f3b8f41383 --- /dev/null +++ b/docs/colossalai/colossalai.tensor.dist_spec_mgr.rst @@ -0,0 +1,5 @@ +colossalai.tensor.dist\_spec\_mgr +================================= + +.. automodule:: colossalai.tensor.dist_spec_mgr + :members: diff --git a/docs/colossalai/colossalai.tensor.distspec.rst b/docs/colossalai/colossalai.tensor.distspec.rst new file mode 100644 index 0000000000000000000000000000000000000000..2b4b0e5fa266e032a8a61674104008d5ba71d699 --- /dev/null +++ b/docs/colossalai/colossalai.tensor.distspec.rst @@ -0,0 +1,5 @@ +colossalai.tensor.distspec +========================== + +.. automodule:: colossalai.tensor.distspec + :members: diff --git a/docs/colossalai/colossalai.tensor.op_wrapper.rst b/docs/colossalai/colossalai.tensor.op_wrapper.rst new file mode 100644 index 0000000000000000000000000000000000000000..a246e0a6a5489efda5a91d9d145109fff8af294a --- /dev/null +++ b/docs/colossalai/colossalai.tensor.op_wrapper.rst @@ -0,0 +1,5 @@ +colossalai.tensor.op\_wrapper +============================= + +.. automodule:: colossalai.tensor.op_wrapper + :members: diff --git a/docs/colossalai/colossalai.tensor.param_op_hook.rst b/docs/colossalai/colossalai.tensor.param_op_hook.rst new file mode 100644 index 0000000000000000000000000000000000000000..475ada452bb21ea7d5fe5e683a2097631068c14d --- /dev/null +++ b/docs/colossalai/colossalai.tensor.param_op_hook.rst @@ -0,0 +1,5 @@ +colossalai.tensor.param\_op\_hook +================================= + +.. automodule:: colossalai.tensor.param_op_hook + :members: diff --git a/docs/colossalai/colossalai.tensor.process_group.rst b/docs/colossalai/colossalai.tensor.process_group.rst new file mode 100644 index 0000000000000000000000000000000000000000..b71409e3bd11017b0b172c03f438986c5a7cb07f --- /dev/null +++ b/docs/colossalai/colossalai.tensor.process_group.rst @@ -0,0 +1,5 @@ +colossalai.tensor.process\_group +================================ + +.. automodule:: colossalai.tensor.process_group + :members: diff --git a/docs/colossalai/colossalai.tensor.rst b/docs/colossalai/colossalai.tensor.rst new file mode 100644 index 0000000000000000000000000000000000000000..68e06552b873f9b4f90b646ec493f67071ecdb77 --- /dev/null +++ b/docs/colossalai/colossalai.tensor.rst @@ -0,0 +1,21 @@ +colossalai.tensor +================= + +.. automodule:: colossalai.tensor + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.tensor.colo_parameter + colossalai.tensor.colo_tensor + colossalai.tensor.compute_spec + colossalai.tensor.const + colossalai.tensor.dist_spec_mgr + colossalai.tensor.distspec + colossalai.tensor.op_wrapper + colossalai.tensor.param_op_hook + colossalai.tensor.process_group + colossalai.tensor.tensor_spec + colossalai.tensor.utils diff --git a/docs/colossalai/colossalai.tensor.tensor_spec.rst b/docs/colossalai/colossalai.tensor.tensor_spec.rst new file mode 100644 index 0000000000000000000000000000000000000000..7125b9cbc28d3ca6117102c4b4a55e1f57837625 --- /dev/null +++ b/docs/colossalai/colossalai.tensor.tensor_spec.rst @@ -0,0 +1,5 @@ +colossalai.tensor.tensor\_spec +============================== + +.. automodule:: colossalai.tensor.tensor_spec + :members: diff --git a/docs/colossalai/colossalai.tensor.utils.rst b/docs/colossalai/colossalai.tensor.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..5d9bd1b030381807a342477475e758fa886ffd3e --- /dev/null +++ b/docs/colossalai/colossalai.tensor.utils.rst @@ -0,0 +1,5 @@ +colossalai.tensor.utils +======================= + +.. automodule:: colossalai.tensor.utils + :members: diff --git a/docs/colossalai/colossalai.testing.comparison.rst b/docs/colossalai/colossalai.testing.comparison.rst new file mode 100644 index 0000000000000000000000000000000000000000..bcfdf0598856a8b0348f70409ae1109466774275 --- /dev/null +++ b/docs/colossalai/colossalai.testing.comparison.rst @@ -0,0 +1,5 @@ +colossalai.testing.comparison +============================= + +.. automodule:: colossalai.testing.comparison + :members: diff --git a/docs/colossalai/colossalai.testing.rst b/docs/colossalai/colossalai.testing.rst new file mode 100644 index 0000000000000000000000000000000000000000..1127aa52c1add5aa0a4f9d0d61e8e805986fda90 --- /dev/null +++ b/docs/colossalai/colossalai.testing.rst @@ -0,0 +1,12 @@ +colossalai.testing +================== + +.. automodule:: colossalai.testing + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.testing.comparison + colossalai.testing.utils diff --git a/docs/colossalai/colossalai.testing.utils.rst b/docs/colossalai/colossalai.testing.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..d8c2edcce71c5ac29f749d3ccd941d79e6097594 --- /dev/null +++ b/docs/colossalai/colossalai.testing.utils.rst @@ -0,0 +1,5 @@ +colossalai.testing.utils +======================== + +.. automodule:: colossalai.testing.utils + :members: diff --git a/docs/colossalai/colossalai.trainer.hooks.rst b/docs/colossalai/colossalai.trainer.hooks.rst new file mode 100644 index 0000000000000000000000000000000000000000..84cc6797b83138669f216162e7d19ff024a5e21f --- /dev/null +++ b/docs/colossalai/colossalai.trainer.hooks.rst @@ -0,0 +1,5 @@ +colossalai.trainer.hooks +======================== + +.. automodule:: colossalai.trainer.hooks + :members: diff --git a/docs/colossalai/colossalai.trainer.rst b/docs/colossalai/colossalai.trainer.rst new file mode 100644 index 0000000000000000000000000000000000000000..abc636e623737ab86669521c6262c9352e504d2c --- /dev/null +++ b/docs/colossalai/colossalai.trainer.rst @@ -0,0 +1,10 @@ +colossalai.trainer +================== + +.. automodule:: colossalai.trainer + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.trainer.hooks diff --git a/docs/colossalai/colossalai.utils.activation_checkpoint.rst b/docs/colossalai/colossalai.utils.activation_checkpoint.rst new file mode 100644 index 0000000000000000000000000000000000000000..671b5fe9e9c452ea608e4fd9e74b2046a13073d3 --- /dev/null +++ b/docs/colossalai/colossalai.utils.activation_checkpoint.rst @@ -0,0 +1,5 @@ +colossalai.utils.activation\_checkpoint +======================================= + +.. automodule:: colossalai.utils.activation_checkpoint + :members: diff --git a/docs/colossalai/colossalai.utils.checkpoint.module_checkpoint.rst b/docs/colossalai/colossalai.utils.checkpoint.module_checkpoint.rst new file mode 100644 index 0000000000000000000000000000000000000000..237ad380b301910c81cac55a6c7ad2545c7761dd --- /dev/null +++ b/docs/colossalai/colossalai.utils.checkpoint.module_checkpoint.rst @@ -0,0 +1,5 @@ +colossalai.utils.checkpoint.module\_checkpoint +============================================== + +.. automodule:: colossalai.utils.checkpoint.module_checkpoint + :members: diff --git a/docs/colossalai/colossalai.utils.checkpoint.rst b/docs/colossalai/colossalai.utils.checkpoint.rst new file mode 100644 index 0000000000000000000000000000000000000000..220c270f09b9d63c6bde9fd4b8ce4ea691846d4c --- /dev/null +++ b/docs/colossalai/colossalai.utils.checkpoint.rst @@ -0,0 +1,12 @@ +colossalai.utils.checkpoint +=========================== + +.. automodule:: colossalai.utils.checkpoint + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.utils.checkpoint.module_checkpoint + colossalai.utils.checkpoint.utils diff --git a/docs/colossalai/colossalai.utils.checkpoint.utils.rst b/docs/colossalai/colossalai.utils.checkpoint.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..7fdeefd539fefa10303d66afb5a4d05f8850469e --- /dev/null +++ b/docs/colossalai/colossalai.utils.checkpoint.utils.rst @@ -0,0 +1,5 @@ +colossalai.utils.checkpoint.utils +================================= + +.. automodule:: colossalai.utils.checkpoint.utils + :members: diff --git a/docs/colossalai/colossalai.utils.checkpointing.rst b/docs/colossalai/colossalai.utils.checkpointing.rst new file mode 100644 index 0000000000000000000000000000000000000000..534a581d536406a26a288f39d6f761d60c16869f --- /dev/null +++ b/docs/colossalai/colossalai.utils.checkpointing.rst @@ -0,0 +1,5 @@ +colossalai.utils.checkpointing +============================== + +.. automodule:: colossalai.utils.checkpointing + :members: diff --git a/docs/colossalai/colossalai.utils.common.rst b/docs/colossalai/colossalai.utils.common.rst new file mode 100644 index 0000000000000000000000000000000000000000..cb9f9c14ef4fb14cda1058ee9783a970c5365a74 --- /dev/null +++ b/docs/colossalai/colossalai.utils.common.rst @@ -0,0 +1,5 @@ +colossalai.utils.common +======================= + +.. automodule:: colossalai.utils.common + :members: diff --git a/docs/colossalai/colossalai.utils.cuda.rst b/docs/colossalai/colossalai.utils.cuda.rst new file mode 100644 index 0000000000000000000000000000000000000000..ec428c5ef6ea2e3f4fe9b3ce0def3fe2417fd1f3 --- /dev/null +++ b/docs/colossalai/colossalai.utils.cuda.rst @@ -0,0 +1,5 @@ +colossalai.utils.cuda +===================== + +.. automodule:: colossalai.utils.cuda + :members: diff --git a/docs/colossalai/colossalai.utils.data_sampler.base_sampler.rst b/docs/colossalai/colossalai.utils.data_sampler.base_sampler.rst new file mode 100644 index 0000000000000000000000000000000000000000..199e8fcf83c35c9303baad559a0e10da27197d52 --- /dev/null +++ b/docs/colossalai/colossalai.utils.data_sampler.base_sampler.rst @@ -0,0 +1,5 @@ +colossalai.utils.data\_sampler.base\_sampler +============================================ + +.. automodule:: colossalai.utils.data_sampler.base_sampler + :members: diff --git a/docs/colossalai/colossalai.utils.data_sampler.data_parallel_sampler.rst b/docs/colossalai/colossalai.utils.data_sampler.data_parallel_sampler.rst new file mode 100644 index 0000000000000000000000000000000000000000..85e1b121c682310dc8f9930df90f06e1ed32ae80 --- /dev/null +++ b/docs/colossalai/colossalai.utils.data_sampler.data_parallel_sampler.rst @@ -0,0 +1,5 @@ +colossalai.utils.data\_sampler.data\_parallel\_sampler +====================================================== + +.. automodule:: colossalai.utils.data_sampler.data_parallel_sampler + :members: diff --git a/docs/colossalai/colossalai.utils.data_sampler.rst b/docs/colossalai/colossalai.utils.data_sampler.rst new file mode 100644 index 0000000000000000000000000000000000000000..61dde070bad445a582bf9e198402b76b1768623d --- /dev/null +++ b/docs/colossalai/colossalai.utils.data_sampler.rst @@ -0,0 +1,12 @@ +colossalai.utils.data\_sampler +============================== + +.. automodule:: colossalai.utils.data_sampler + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.utils.data_sampler.base_sampler + colossalai.utils.data_sampler.data_parallel_sampler diff --git a/docs/colossalai/colossalai.utils.memory.rst b/docs/colossalai/colossalai.utils.memory.rst new file mode 100644 index 0000000000000000000000000000000000000000..67c5d60022dddf5293b66ea048cdc13b6bc6bdaa --- /dev/null +++ b/docs/colossalai/colossalai.utils.memory.rst @@ -0,0 +1,5 @@ +colossalai.utils.memory +======================= + +.. automodule:: colossalai.utils.memory + :members: diff --git a/docs/colossalai/colossalai.utils.model.colo_init_context.rst b/docs/colossalai/colossalai.utils.model.colo_init_context.rst new file mode 100644 index 0000000000000000000000000000000000000000..33ee449150835a4a453715d58db80cfe3340f664 --- /dev/null +++ b/docs/colossalai/colossalai.utils.model.colo_init_context.rst @@ -0,0 +1,5 @@ +colossalai.utils.model.colo\_init\_context +========================================== + +.. automodule:: colossalai.utils.model.colo_init_context + :members: diff --git a/docs/colossalai/colossalai.utils.model.lazy_init_context.rst b/docs/colossalai/colossalai.utils.model.lazy_init_context.rst new file mode 100644 index 0000000000000000000000000000000000000000..27c9a32c6a7d7bac9055e86b7c7b093388a4e5fa --- /dev/null +++ b/docs/colossalai/colossalai.utils.model.lazy_init_context.rst @@ -0,0 +1,5 @@ +colossalai.utils.model.lazy\_init\_context +========================================== + +.. automodule:: colossalai.utils.model.lazy_init_context + :members: diff --git a/docs/colossalai/colossalai.utils.model.rst b/docs/colossalai/colossalai.utils.model.rst new file mode 100644 index 0000000000000000000000000000000000000000..9adfd1450a47a7e3741b182331d704fb6cad8764 --- /dev/null +++ b/docs/colossalai/colossalai.utils.model.rst @@ -0,0 +1,13 @@ +colossalai.utils.model +====================== + +.. automodule:: colossalai.utils.model + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.utils.model.colo_init_context + colossalai.utils.model.lazy_init_context + colossalai.utils.model.utils diff --git a/docs/colossalai/colossalai.utils.model.utils.rst b/docs/colossalai/colossalai.utils.model.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..211106662dc33289d63ff1a86222d87dd488f77d --- /dev/null +++ b/docs/colossalai/colossalai.utils.model.utils.rst @@ -0,0 +1,5 @@ +colossalai.utils.model.utils +============================ + +.. automodule:: colossalai.utils.model.utils + :members: diff --git a/docs/colossalai/colossalai.utils.moe.rst b/docs/colossalai/colossalai.utils.moe.rst new file mode 100644 index 0000000000000000000000000000000000000000..b66ccdc8ec2dfaa3d713179ec2d8d2ddd4b58d34 --- /dev/null +++ b/docs/colossalai/colossalai.utils.moe.rst @@ -0,0 +1,5 @@ +colossalai.utils.moe +==================== + +.. automodule:: colossalai.utils.moe + :members: diff --git a/docs/colossalai/colossalai.utils.multi_tensor_apply.multi_tensor_apply.rst b/docs/colossalai/colossalai.utils.multi_tensor_apply.multi_tensor_apply.rst new file mode 100644 index 0000000000000000000000000000000000000000..493b9530e0f614409ce33c3a3c6f013c261e546b --- /dev/null +++ b/docs/colossalai/colossalai.utils.multi_tensor_apply.multi_tensor_apply.rst @@ -0,0 +1,5 @@ +colossalai.utils.multi\_tensor\_apply.multi\_tensor\_apply +========================================================== + +.. automodule:: colossalai.utils.multi_tensor_apply.multi_tensor_apply + :members: diff --git a/docs/colossalai/colossalai.utils.multi_tensor_apply.rst b/docs/colossalai/colossalai.utils.multi_tensor_apply.rst new file mode 100644 index 0000000000000000000000000000000000000000..d5749cfa8801c4ad5f38b6037e1621e6ea011ab8 --- /dev/null +++ b/docs/colossalai/colossalai.utils.multi_tensor_apply.rst @@ -0,0 +1,11 @@ +colossalai.utils.multi\_tensor\_apply +===================================== + +.. automodule:: colossalai.utils.multi_tensor_apply + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.utils.multi_tensor_apply.multi_tensor_apply diff --git a/docs/colossalai/colossalai.utils.profiler.extention.rst b/docs/colossalai/colossalai.utils.profiler.extention.rst new file mode 100644 index 0000000000000000000000000000000000000000..5c87692611a06f002543418ef726948f411674ed --- /dev/null +++ b/docs/colossalai/colossalai.utils.profiler.extention.rst @@ -0,0 +1,5 @@ +colossalai.utils.profiler.extention +=================================== + +.. automodule:: colossalai.utils.profiler.extention + :members: diff --git a/docs/colossalai/colossalai.utils.profiler.legacy.comm_profiler.rst b/docs/colossalai/colossalai.utils.profiler.legacy.comm_profiler.rst new file mode 100644 index 0000000000000000000000000000000000000000..4329a3d60da31e4cd7c1888a05a2680cecb2e3f4 --- /dev/null +++ b/docs/colossalai/colossalai.utils.profiler.legacy.comm_profiler.rst @@ -0,0 +1,5 @@ +colossalai.utils.profiler.legacy.comm\_profiler +=============================================== + +.. automodule:: colossalai.utils.profiler.legacy.comm_profiler + :members: diff --git a/docs/colossalai/colossalai.utils.profiler.legacy.mem_profiler.rst b/docs/colossalai/colossalai.utils.profiler.legacy.mem_profiler.rst new file mode 100644 index 0000000000000000000000000000000000000000..35c665c71d3b5d7a7c86c5e915ed305513866d84 --- /dev/null +++ b/docs/colossalai/colossalai.utils.profiler.legacy.mem_profiler.rst @@ -0,0 +1,5 @@ +colossalai.utils.profiler.legacy.mem\_profiler +============================================== + +.. automodule:: colossalai.utils.profiler.legacy.mem_profiler + :members: diff --git a/docs/colossalai/colossalai.utils.profiler.legacy.pcie_profiler.rst b/docs/colossalai/colossalai.utils.profiler.legacy.pcie_profiler.rst new file mode 100644 index 0000000000000000000000000000000000000000..7aa82b8f7a4f4e6ddad11323beed57866be6c1a9 --- /dev/null +++ b/docs/colossalai/colossalai.utils.profiler.legacy.pcie_profiler.rst @@ -0,0 +1,5 @@ +colossalai.utils.profiler.legacy.pcie\_profiler +=============================================== + +.. automodule:: colossalai.utils.profiler.legacy.pcie_profiler + :members: diff --git a/docs/colossalai/colossalai.utils.profiler.legacy.prof_utils.rst b/docs/colossalai/colossalai.utils.profiler.legacy.prof_utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..93af82b2fabbe5dc018677966fa9be782351def1 --- /dev/null +++ b/docs/colossalai/colossalai.utils.profiler.legacy.prof_utils.rst @@ -0,0 +1,5 @@ +colossalai.utils.profiler.legacy.prof\_utils +============================================ + +.. automodule:: colossalai.utils.profiler.legacy.prof_utils + :members: diff --git a/docs/colossalai/colossalai.utils.profiler.legacy.rst b/docs/colossalai/colossalai.utils.profiler.legacy.rst new file mode 100644 index 0000000000000000000000000000000000000000..37fcebde5a43d066e8e1186abf919a2b07a10283 --- /dev/null +++ b/docs/colossalai/colossalai.utils.profiler.legacy.rst @@ -0,0 +1,14 @@ +colossalai.utils.profiler.legacy +================================ + +.. automodule:: colossalai.utils.profiler.legacy + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.utils.profiler.legacy.comm_profiler + colossalai.utils.profiler.legacy.mem_profiler + colossalai.utils.profiler.legacy.pcie_profiler + colossalai.utils.profiler.legacy.prof_utils diff --git a/docs/colossalai/colossalai.utils.profiler.profiler.rst b/docs/colossalai/colossalai.utils.profiler.profiler.rst new file mode 100644 index 0000000000000000000000000000000000000000..d35522837801852e914455219cc96cb5a78bbce4 --- /dev/null +++ b/docs/colossalai/colossalai.utils.profiler.profiler.rst @@ -0,0 +1,5 @@ +colossalai.utils.profiler.profiler +================================== + +.. automodule:: colossalai.utils.profiler.profiler + :members: diff --git a/docs/colossalai/colossalai.utils.profiler.rst b/docs/colossalai/colossalai.utils.profiler.rst new file mode 100644 index 0000000000000000000000000000000000000000..15681fcf2d82932a4f9419cd1aae37ee2921647f --- /dev/null +++ b/docs/colossalai/colossalai.utils.profiler.rst @@ -0,0 +1,18 @@ +colossalai.utils.profiler +========================= + +.. automodule:: colossalai.utils.profiler + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.utils.profiler.legacy + + +.. toctree:: + :maxdepth: 2 + + colossalai.utils.profiler.extention + colossalai.utils.profiler.profiler + colossalai.utils.profiler.stateful_tensor_mem_extention diff --git a/docs/colossalai/colossalai.utils.profiler.stateful_tensor_mem_extention.rst b/docs/colossalai/colossalai.utils.profiler.stateful_tensor_mem_extention.rst new file mode 100644 index 0000000000000000000000000000000000000000..72a3fcceca1898230793758670fd5c0bb6e64fb0 --- /dev/null +++ b/docs/colossalai/colossalai.utils.profiler.stateful_tensor_mem_extention.rst @@ -0,0 +1,5 @@ +colossalai.utils.profiler.stateful\_tensor\_mem\_extention +========================================================== + +.. automodule:: colossalai.utils.profiler.stateful_tensor_mem_extention + :members: diff --git a/docs/colossalai/colossalai.utils.rst b/docs/colossalai/colossalai.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..8b232a12c24566d30f329fe2f23af5996c5125e4 --- /dev/null +++ b/docs/colossalai/colossalai.utils.rst @@ -0,0 +1,27 @@ +colossalai.utils +================ + +.. automodule:: colossalai.utils + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.utils.checkpoint + colossalai.utils.data_sampler + colossalai.utils.model + colossalai.utils.multi_tensor_apply + colossalai.utils.profiler + colossalai.utils.tensor_detector + + +.. toctree:: + :maxdepth: 2 + + colossalai.utils.activation_checkpoint + colossalai.utils.checkpointing + colossalai.utils.common + colossalai.utils.cuda + colossalai.utils.memory + colossalai.utils.moe + colossalai.utils.timer diff --git a/docs/colossalai/colossalai.utils.tensor_detector.rst b/docs/colossalai/colossalai.utils.tensor_detector.rst new file mode 100644 index 0000000000000000000000000000000000000000..807d67e3ad1e5192414ee2c3cb129b59a209b148 --- /dev/null +++ b/docs/colossalai/colossalai.utils.tensor_detector.rst @@ -0,0 +1,11 @@ +colossalai.utils.tensor\_detector +================================= + +.. automodule:: colossalai.utils.tensor_detector + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.utils.tensor_detector.tensor_detector diff --git a/docs/colossalai/colossalai.utils.tensor_detector.tensor_detector.rst b/docs/colossalai/colossalai.utils.tensor_detector.tensor_detector.rst new file mode 100644 index 0000000000000000000000000000000000000000..991cea3438b3f350c7c8678be408da2be367e02b --- /dev/null +++ b/docs/colossalai/colossalai.utils.tensor_detector.tensor_detector.rst @@ -0,0 +1,5 @@ +colossalai.utils.tensor\_detector.tensor\_detector +================================================== + +.. automodule:: colossalai.utils.tensor_detector.tensor_detector + :members: diff --git a/docs/colossalai/colossalai.utils.timer.rst b/docs/colossalai/colossalai.utils.timer.rst new file mode 100644 index 0000000000000000000000000000000000000000..2014c85f548f6e6d4211bba74b95656d2fe30ef8 --- /dev/null +++ b/docs/colossalai/colossalai.utils.timer.rst @@ -0,0 +1,5 @@ +colossalai.utils.timer +====================== + +.. automodule:: colossalai.utils.timer + :members: diff --git a/docs/colossalai/colossalai.zero.init_ctx.init_context.rst b/docs/colossalai/colossalai.zero.init_ctx.init_context.rst new file mode 100644 index 0000000000000000000000000000000000000000..1694074e83bff8c53952dc356603b7886da89d8a --- /dev/null +++ b/docs/colossalai/colossalai.zero.init_ctx.init_context.rst @@ -0,0 +1,5 @@ +colossalai.zero.init\_ctx.init\_context +======================================= + +.. automodule:: colossalai.zero.init_ctx.init_context + :members: diff --git a/docs/colossalai/colossalai.zero.init_ctx.rst b/docs/colossalai/colossalai.zero.init_ctx.rst new file mode 100644 index 0000000000000000000000000000000000000000..88cf471df9d30ebc2c9ed881f3da582259714348 --- /dev/null +++ b/docs/colossalai/colossalai.zero.init_ctx.rst @@ -0,0 +1,11 @@ +colossalai.zero.init\_ctx +========================= + +.. automodule:: colossalai.zero.init_ctx + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.zero.init_ctx.init_context diff --git a/docs/colossalai/colossalai.zero.rst b/docs/colossalai/colossalai.zero.rst new file mode 100644 index 0000000000000000000000000000000000000000..3bcaffd28d052585db38740b432e6f1e296ac118 --- /dev/null +++ b/docs/colossalai/colossalai.zero.rst @@ -0,0 +1,21 @@ +colossalai.zero +=============== + +.. automodule:: colossalai.zero + :members: + +.. toctree:: + :maxdepth: 2 + + colossalai.zero.init_ctx + colossalai.zero.shard_utils + colossalai.zero.sharded_model + colossalai.zero.sharded_optim + colossalai.zero.sharded_param + colossalai.zero.utils + + +.. toctree:: + :maxdepth: 2 + + colossalai.zero.zero_optimizer diff --git a/docs/colossalai/colossalai.zero.shard_utils.base_shard_strategy.rst b/docs/colossalai/colossalai.zero.shard_utils.base_shard_strategy.rst new file mode 100644 index 0000000000000000000000000000000000000000..d5b59e06a517973b4817abd86d8c8eed4abfc848 --- /dev/null +++ b/docs/colossalai/colossalai.zero.shard_utils.base_shard_strategy.rst @@ -0,0 +1,5 @@ +colossalai.zero.shard\_utils.base\_shard\_strategy +================================================== + +.. automodule:: colossalai.zero.shard_utils.base_shard_strategy + :members: diff --git a/docs/colossalai/colossalai.zero.shard_utils.bucket_tensor_shard_strategy.rst b/docs/colossalai/colossalai.zero.shard_utils.bucket_tensor_shard_strategy.rst new file mode 100644 index 0000000000000000000000000000000000000000..952c5bbddf096a7c067f7a4bc2df06bade2e7880 --- /dev/null +++ b/docs/colossalai/colossalai.zero.shard_utils.bucket_tensor_shard_strategy.rst @@ -0,0 +1,5 @@ +colossalai.zero.shard\_utils.bucket\_tensor\_shard\_strategy +============================================================ + +.. automodule:: colossalai.zero.shard_utils.bucket_tensor_shard_strategy + :members: diff --git a/docs/colossalai/colossalai.zero.shard_utils.commons.rst b/docs/colossalai/colossalai.zero.shard_utils.commons.rst new file mode 100644 index 0000000000000000000000000000000000000000..aa6682d79ff2461217d4daa47e7684d7e5944de5 --- /dev/null +++ b/docs/colossalai/colossalai.zero.shard_utils.commons.rst @@ -0,0 +1,5 @@ +colossalai.zero.shard\_utils.commons +==================================== + +.. automodule:: colossalai.zero.shard_utils.commons + :members: diff --git a/docs/colossalai/colossalai.zero.shard_utils.rst b/docs/colossalai/colossalai.zero.shard_utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..580bfdab7d852fdd875c7807c6b14cd62e78e0ac --- /dev/null +++ b/docs/colossalai/colossalai.zero.shard_utils.rst @@ -0,0 +1,14 @@ +colossalai.zero.shard\_utils +============================ + +.. automodule:: colossalai.zero.shard_utils + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.zero.shard_utils.base_shard_strategy + colossalai.zero.shard_utils.bucket_tensor_shard_strategy + colossalai.zero.shard_utils.commons + colossalai.zero.shard_utils.tensor_shard_strategy diff --git a/docs/colossalai/colossalai.zero.shard_utils.tensor_shard_strategy.rst b/docs/colossalai/colossalai.zero.shard_utils.tensor_shard_strategy.rst new file mode 100644 index 0000000000000000000000000000000000000000..571b7bd7a588c59f8b05ccc5f25f61e08eca104e --- /dev/null +++ b/docs/colossalai/colossalai.zero.shard_utils.tensor_shard_strategy.rst @@ -0,0 +1,5 @@ +colossalai.zero.shard\_utils.tensor\_shard\_strategy +==================================================== + +.. automodule:: colossalai.zero.shard_utils.tensor_shard_strategy + :members: diff --git a/docs/colossalai/colossalai.zero.sharded_model.reduce_scatter.rst b/docs/colossalai/colossalai.zero.sharded_model.reduce_scatter.rst new file mode 100644 index 0000000000000000000000000000000000000000..cf861ee70aa01f209ca1fe9a322450ae7e8a76ec --- /dev/null +++ b/docs/colossalai/colossalai.zero.sharded_model.reduce_scatter.rst @@ -0,0 +1,5 @@ +colossalai.zero.sharded\_model.reduce\_scatter +============================================== + +.. automodule:: colossalai.zero.sharded_model.reduce_scatter + :members: diff --git a/docs/colossalai/colossalai.zero.sharded_model.rst b/docs/colossalai/colossalai.zero.sharded_model.rst new file mode 100644 index 0000000000000000000000000000000000000000..fb3f5a8456d0c1504797989b18b580032a40ffc4 --- /dev/null +++ b/docs/colossalai/colossalai.zero.sharded_model.rst @@ -0,0 +1,13 @@ +colossalai.zero.sharded\_model +============================== + +.. automodule:: colossalai.zero.sharded_model + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.zero.sharded_model.reduce_scatter + colossalai.zero.sharded_model.sharded_model_v2 + colossalai.zero.sharded_model.utils diff --git a/docs/colossalai/colossalai.zero.sharded_model.sharded_model_v2.rst b/docs/colossalai/colossalai.zero.sharded_model.sharded_model_v2.rst new file mode 100644 index 0000000000000000000000000000000000000000..a0e191377914900b69d96f64e2fc42712e69c4c7 --- /dev/null +++ b/docs/colossalai/colossalai.zero.sharded_model.sharded_model_v2.rst @@ -0,0 +1,5 @@ +colossalai.zero.sharded\_model.sharded\_model\_v2 +================================================= + +.. automodule:: colossalai.zero.sharded_model.sharded_model_v2 + :members: diff --git a/docs/colossalai/colossalai.zero.sharded_model.utils.rst b/docs/colossalai/colossalai.zero.sharded_model.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..5e376774296f2823d04711f490fd641e0b327ff5 --- /dev/null +++ b/docs/colossalai/colossalai.zero.sharded_model.utils.rst @@ -0,0 +1,5 @@ +colossalai.zero.sharded\_model.utils +==================================== + +.. automodule:: colossalai.zero.sharded_model.utils + :members: diff --git a/docs/colossalai/colossalai.zero.sharded_optim.rst b/docs/colossalai/colossalai.zero.sharded_optim.rst new file mode 100644 index 0000000000000000000000000000000000000000..db3dfdddbab417a9ffa9b06f7f986ee4428c5bec --- /dev/null +++ b/docs/colossalai/colossalai.zero.sharded_optim.rst @@ -0,0 +1,11 @@ +colossalai.zero.sharded\_optim +============================== + +.. automodule:: colossalai.zero.sharded_optim + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.zero.sharded_optim.sharded_optim_v2 diff --git a/docs/colossalai/colossalai.zero.sharded_optim.sharded_optim_v2.rst b/docs/colossalai/colossalai.zero.sharded_optim.sharded_optim_v2.rst new file mode 100644 index 0000000000000000000000000000000000000000..01fbe0c4c031bb48992d6f25f775927c2b6d4e0d --- /dev/null +++ b/docs/colossalai/colossalai.zero.sharded_optim.sharded_optim_v2.rst @@ -0,0 +1,5 @@ +colossalai.zero.sharded\_optim.sharded\_optim\_v2 +================================================= + +.. automodule:: colossalai.zero.sharded_optim.sharded_optim_v2 + :members: diff --git a/docs/colossalai/colossalai.zero.sharded_param.rst b/docs/colossalai/colossalai.zero.sharded_param.rst new file mode 100644 index 0000000000000000000000000000000000000000..02e0fc6c29eb21069cabedb8f020fee0181fd929 --- /dev/null +++ b/docs/colossalai/colossalai.zero.sharded_param.rst @@ -0,0 +1,12 @@ +colossalai.zero.sharded\_param +============================== + +.. automodule:: colossalai.zero.sharded_param + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.zero.sharded_param.sharded_param + colossalai.zero.sharded_param.sharded_tensor diff --git a/docs/colossalai/colossalai.zero.sharded_param.sharded_param.rst b/docs/colossalai/colossalai.zero.sharded_param.sharded_param.rst new file mode 100644 index 0000000000000000000000000000000000000000..efa2f0de379c7b5fd0c8ad3a38182d57251c11ba --- /dev/null +++ b/docs/colossalai/colossalai.zero.sharded_param.sharded_param.rst @@ -0,0 +1,5 @@ +colossalai.zero.sharded\_param.sharded\_param +============================================= + +.. automodule:: colossalai.zero.sharded_param.sharded_param + :members: diff --git a/docs/colossalai/colossalai.zero.sharded_param.sharded_tensor.rst b/docs/colossalai/colossalai.zero.sharded_param.sharded_tensor.rst new file mode 100644 index 0000000000000000000000000000000000000000..930c28de45422fb8959e943e08042c7ef7463193 --- /dev/null +++ b/docs/colossalai/colossalai.zero.sharded_param.sharded_tensor.rst @@ -0,0 +1,5 @@ +colossalai.zero.sharded\_param.sharded\_tensor +============================================== + +.. automodule:: colossalai.zero.sharded_param.sharded_tensor + :members: diff --git a/docs/colossalai/colossalai.zero.utils.rst b/docs/colossalai/colossalai.zero.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..50ee9071e7d5aa0d9680af13795351b64cfe7cff --- /dev/null +++ b/docs/colossalai/colossalai.zero.utils.rst @@ -0,0 +1,12 @@ +colossalai.zero.utils +===================== + +.. automodule:: colossalai.zero.utils + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.zero.utils.zero_hook + colossalai.zero.utils.gemini_hook diff --git a/docs/colossalai/colossalai.zero.utils.zero_hook.rst b/docs/colossalai/colossalai.zero.utils.zero_hook.rst new file mode 100644 index 0000000000000000000000000000000000000000..424f466dd4f5a68990e13fd8a97620b2ca8ba8d9 --- /dev/null +++ b/docs/colossalai/colossalai.zero.utils.zero_hook.rst @@ -0,0 +1,5 @@ +colossalai.zero.utils.zero\_hook +================================ + +.. automodule:: colossalai.zero.utils.zero_hook + :members: diff --git a/docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst b/docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst new file mode 100644 index 0000000000000000000000000000000000000000..e6d6673af13111035f9acf5b2641f1c4a8101aaf --- /dev/null +++ b/docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst @@ -0,0 +1,5 @@ +colossalai.zero.utils.zero\_hook\_v2 +==================================== + +.. automodule:: colossalai.zero.utils.gemini_hook + :members: diff --git a/docs/colossalai/colossalai.zero.zero_optimizer.rst b/docs/colossalai/colossalai.zero.zero_optimizer.rst new file mode 100644 index 0000000000000000000000000000000000000000..b945b081c866639d53cdeefc66fb8feda50bb4ee --- /dev/null +++ b/docs/colossalai/colossalai.zero.zero_optimizer.rst @@ -0,0 +1,5 @@ +colossalai.zero.zero\_optimizer +=============================== + +.. automodule:: colossalai.zero.zero_optimizer + :members: diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..893644f709d4680b5cc6e417c1aa9df46688dc4b --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,135 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +import datetime +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys + +sys.path.insert(0, os.path.abspath('..')) + +# -- Project information ----------------------------------------------------- + +project = 'Colossal-AI' +copyright = f'{datetime.datetime.now().year}, HPC-AI Tech' +author = 'HPC-AI Technology Inc.' + +# The full version, including alpha/beta/rc tags +release = '0.0.1' + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.mathjax', + 'sphinx.ext.napoleon', + 'sphinx.ext.linkcode', + 'myst_parser', +] + +# Disable docstring inheritance +autodoc_inherit_docstrings = False + +# Disable displaying type annotations, these can be very verbose +autodoc_typehints = 'none' + +# Enable overriding of function signatures in the first line of the docstring. +autodoc_docstring_signature = True +autodoc_default_options = { + 'member-order': 'bysource', +} + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['.build', 'Thumbs.db', '.DS_Store'] + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' +html_show_sourcelink = False +html_theme_options = { + 'navigation_depth': 3, +} + +html_context = { + 'display_github': False, + 'github_user': 'hpcaitech', + 'github_repo': 'ColossalAI', + # 'github_version': 'master/docs/', +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +html_css_files = [ + 'css/rtd_theme.css', +] + +# -- Extension configuration ------------------------------------------------- +source_suffix = ['.rst', '.md', '.MD'] + +import inspect +import colossalai +def linkcode_resolve(domain, info): + """ + Determine the URL corresponding to Python object + """ + if domain != 'py': + return None + + modname = info['module'] + fullname = info['fullname'] + + submod = sys.modules.get(modname) + if submod is None: + return None + + obj = submod + for part in fullname.split('.'): + try: + obj = getattr(obj, part) + except Exception: + return None + + try: + fn = inspect.getsourcefile(obj) + except Exception: + fn = None + if not fn: + return None + + try: + source, lineno = inspect.findsource(obj) + except Exception: + lineno = None + + if lineno: + linespec = "#L%d" % (lineno + 1) + else: + linespec = "" + + fn = os.path.relpath(fn, start=os.path.dirname(colossalai.__file__)) + + github = "https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/{}{}" + return github.format(fn, linespec) diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..f275f7829403d107ff2e51340a72f9af6524506a --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,27 @@ +.. Colossal-AI documentation master file, created by + sphinx-quickstart on Mon Oct 11 17:05:05 2021. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Colossal-AI API documentation +====================================== + +.. toctree:: + :maxdepth: 2 + :caption: API REFERENCE + + colossalai/colossalai + +.. toctree:: + :maxdepth: 2 + :caption: Useful links for Colossal-AI + + links/Colossalai examples + links/Colossalai benchmarks + links/Colossalai tutorial + + +Indices and tables +-------------------- + +* :ref:`genindex` diff --git a/docs/links/Colossalai Homepage.rst b/docs/links/Colossalai Homepage.rst new file mode 100644 index 0000000000000000000000000000000000000000..38e223bd22c97daa3950a19d5b48773aa4777530 --- /dev/null +++ b/docs/links/Colossalai Homepage.rst @@ -0,0 +1,6 @@ +Colossal-AI Github Homepage +================================== + +*If you are looking for the Git homepage of Colossal-AI, please check* +`Colossal-AI Tutorial `_ +*for our source code.* \ No newline at end of file diff --git a/docs/links/Colossalai benchmarks.rst b/docs/links/Colossalai benchmarks.rst new file mode 100644 index 0000000000000000000000000000000000000000..1835670a5f2a5ace6636058a74e2472492cd9095 --- /dev/null +++ b/docs/links/Colossalai benchmarks.rst @@ -0,0 +1,6 @@ +Colossal-AI Benchmarks +================================== + +*If you are interested in the performance or the features of Colossal-AI, please check* +`Colossal-AI Benchmark `_. +*to get more details about our performance on CIFAR10, ImageNet1K or GPT2 ZeRO.* \ No newline at end of file diff --git a/docs/links/Colossalai examples.rst b/docs/links/Colossalai examples.rst new file mode 100644 index 0000000000000000000000000000000000000000..c375f007a3ffcefb009b0103ca64a4bc1c76048c --- /dev/null +++ b/docs/links/Colossalai examples.rst @@ -0,0 +1,6 @@ +Colossal-AI Examples +================================== + +*If you are looking for the example code of using Colossal-AI in CV or NLP, please check* +`Colossal-AI Example `_ +*to get more details about using colossalai in Resnet, Moe, Vit, Bert and GPT* \ No newline at end of file diff --git a/docs/links/Colossalai tutorial.rst b/docs/links/Colossalai tutorial.rst new file mode 100644 index 0000000000000000000000000000000000000000..a4ab7f5b906b10e1c3b4fc06cfcd872380bc925b --- /dev/null +++ b/docs/links/Colossalai tutorial.rst @@ -0,0 +1,7 @@ +Colossal-AI Tutorial +================================== + +*If you are looking for the tutorial of using Colossal-AI, please check* +`Colossal-AI Tutorial `_ +*to get more details about getting started, using TP (tensor parallel), PP (pipeline parallel) +and training with colossalai trainer or engine.* \ No newline at end of file diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..7c6fd5d32752ded5d3b15fa8b988a428d38c2fe2 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=.build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2b3b1a25bca443a0ed66408827227eb0ace1d785 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,5 @@ +tensorboard +apex +sphinx +sphinx-rtd-theme +myst-parser diff --git a/examples/images/diffusion/LICENSE b/examples/images/diffusion/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0e609df0d8cd3b5d11a1ea962a56b604b70846a5 --- /dev/null +++ b/examples/images/diffusion/LICENSE @@ -0,0 +1,82 @@ +Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors + +CreativeML Open RAIL-M +dated August 22, 2022 + +Section I: PREAMBLE + +Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation. + +Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations. + +In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation. + +Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI. + +This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model. + +NOW THEREFORE, You and Licensor agree as follows: + +1. Definitions + +- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document. +- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. +- "Output" means the results of operating a Model as embodied in informational content resulting therefrom. +- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material. +- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model. +- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any. +- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access. +- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model. +- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator. +- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You. +- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." +- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model. + +Section II: INTELLECTUAL PROPERTY RIGHTS + +Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model. +3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed. + +Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION + +4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions: +Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material. +You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License; +You must cause any modified files to carry prominent notices stating that You changed the files; +You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model. +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License. +5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5). +6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License. + +Section IV: OTHER PROVISIONS + +7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model. +8. Trademarks and related. Nothing in this License permits You to make use of Licensorsโ€™ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors. +9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License. +10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. +11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. +12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. + +END OF TERMS AND CONDITIONS + + + + +Attachment A + +Use Restrictions + +You agree not to use the Model or Derivatives of the Model: +- In any way that violates any applicable national, federal, state, local or international law or regulation; +- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; +- To generate or disseminate verifiably false information and/or content with the purpose of harming others; +- To generate or disseminate personal identifiable information that can be used to harm an individual; +- To defame, disparage or otherwise harass others; +- For fully automated decision making that adversely impacts an individualโ€™s legal rights or otherwise creates or modifies a binding, enforceable obligation; +- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics; +- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; +- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories; +- To provide medical advice and medical results interpretation; +- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use). diff --git a/examples/images/diffusion/README.md b/examples/images/diffusion/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fa8cd28c25f325c850a9d7b1bce4bc3a06ea6f4e --- /dev/null +++ b/examples/images/diffusion/README.md @@ -0,0 +1,195 @@ +# ColoDiffusion: Stable Diffusion with Colossal-AI + +*[Colosssal-AI](https://github.com/hpcaitech/ColossalAI) provides a faster and lower cost solution for pretraining and +fine-tuning for AIGC (AI-Generated Content) applications such as the model [stable-diffusion](https://github.com/CompVis/stable-diffusion) from [Stability AI](https://stability.ai/).* + +We take advantage of [Colosssal-AI](https://github.com/hpcaitech/ColossalAI) to exploit multiple optimization strategies +, e.g. data parallelism, tensor parallelism, mixed precision & ZeRO, to scale the training to multiple GPUs. + +## Stable Diffusion + +[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) is a latent text-to-image diffusion +model. +Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database. +Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487), +this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. + +

+ +

+ +[Stable Diffusion with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion) provides **6.5x faster training and pretraining cost saving, the hardware cost of fine-tuning can be almost 7X cheaper** (from RTX3090/4090 24GB to RTX3050/2070 8GB). + +

+ +

+ +## Requirements + +A suitable [conda](https://conda.io/) environment named `ldm` can be created +and activated with: + +``` +conda env create -f environment.yaml +conda activate ldm +``` + +You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running + +``` +conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch +pip install transformers==4.19.2 diffusers invisible-watermark +pip install -e . +``` + +### install lightning + +``` +git clone https://github.com/1SAA/lightning.git +git checkout strategy/colossalai +export PACKAGE_NAME=pytorch +pip install . +``` + +### Install [Colossal-AI v0.1.10](https://colossalai.org/download/) From Our Official Website + +``` +pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org +``` + +> The specified version is due to the interface incompatibility caused by the latest update of [Lightning](https://github.com/Lightning-AI/lightning), which will be fixed in the near future. + +## Download the model checkpoint from pretrained + +### stable-diffusion-v1-4 + +Our default model config use the weight from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4?text=A+mecha+robot+in+a+favela+in+expressionist+style) + +``` +git lfs install +git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 +``` + +### stable-diffusion-v1-5 from runway + +If you want to useed the Last [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) wiegh from runwayml + +``` +git lfs install +git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 +``` + +## Dataset + +The dataSet is from [LAION-5B](https://laion.ai/blog/laion-5b/), the subset of [LAION](https://laion.ai/), +you should the change the `data.file_path` in the `config/train_colossalai.yaml` + +## Training + +We provide the script `train.sh` to run the training task , and two Stategy in `configs`:`train_colossalai.yaml` and `train_ddp.yaml` + +For example, you can run the training from colossalai by +``` +python main.py --logdir /tmp/ -t -b configs/train_colossalai.yaml +``` + +- you can change the `--logdir` the save the log information and the last checkpoint + +### Training config + +You can change the trainging config in the yaml file + +- accelerator: acceleratortype, default 'gpu' +- devices: device number used for training, default 4 +- max_epochs: max training epochs +- precision: usefp16 for training or not, default 16, you must use fp16 if you want to apply colossalai + +## Finetone Example +### Training on Teyvat Datasets + +We provide the finetuning example on [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset, which is create by BLIP generated captions. + +You can run by config `configs/Teyvat/train_colossalai_teyvat.yaml` +``` +python main.py --logdir /tmp/ -t -b configs/Teyvat/train_colossalai_teyvat.yaml +``` + +## Inference +you can get yout training last.ckpt and train config.yaml in your `--logdir`, and run by +``` +python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms + --outdir ./output \ + --config path/to/logdir/checkpoints/last.ckpt \ + --ckpt /path/to/logdir/configs/project.yaml \ +``` + +```commandline +usage: txt2img.py [-h] [--prompt [PROMPT]] [--outdir [OUTDIR]] [--skip_grid] [--skip_save] [--ddim_steps DDIM_STEPS] [--plms] [--laion400m] [--fixed_code] [--ddim_eta DDIM_ETA] + [--n_iter N_ITER] [--H H] [--W W] [--C C] [--f F] [--n_samples N_SAMPLES] [--n_rows N_ROWS] [--scale SCALE] [--from-file FROM_FILE] [--config CONFIG] [--ckpt CKPT] + [--seed SEED] [--precision {full,autocast}] + +optional arguments: + -h, --help show this help message and exit + --prompt [PROMPT] the prompt to render + --outdir [OUTDIR] dir to write results to + --skip_grid do not save a grid, only individual samples. Helpful when evaluating lots of samples + --skip_save do not save individual samples. For speed measurements. + --ddim_steps DDIM_STEPS + number of ddim sampling steps + --plms use plms sampling + --laion400m uses the LAION400M model + --fixed_code if enabled, uses the same starting code across samples + --ddim_eta DDIM_ETA ddim eta (eta=0.0 corresponds to deterministic sampling + --n_iter N_ITER sample this often + --H H image height, in pixel space + --W W image width, in pixel space + --C C latent channels + --f F downsampling factor + --n_samples N_SAMPLES + how many samples to produce for each given prompt. A.k.a. batch size + --n_rows N_ROWS rows in the grid (default: n_samples) + --scale SCALE unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) + --from-file FROM_FILE + if specified, load prompts from this file + --config CONFIG path to config which constructs model + --ckpt CKPT path to checkpoint of model + --seed SEED the seed (for reproducible sampling) + --precision {full,autocast} + evaluate at this precision +``` + +## Comments + +- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion) +, [lucidrains](https://github.com/lucidrains/denoising-diffusion-pytorch), +[Stable Diffusion](https://github.com/CompVis/stable-diffusion), [Lightning](https://github.com/Lightning-AI/lightning) and [Hugging Face](https://huggingface.co/CompVis/stable-diffusion). +Thanks for open-sourcing! + +- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories). + +- The implementation of [flash attention](https://github.com/HazyResearch/flash-attention) is from [HazyResearch](https://github.com/HazyResearch). + +## BibTeX + +``` +@article{bian2021colossal, + title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, + author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, + journal={arXiv preprint arXiv:2110.14883}, + year={2021} +} +@misc{rombach2021highresolution, + title={High-Resolution Image Synthesis with Latent Diffusion Models}, + author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Bjรถrn Ommer}, + year={2021}, + eprint={2112.10752}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +@article{dao2022flashattention, + title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness}, + author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, + journal={arXiv preprint arXiv:2205.14135}, + year={2022} +} +``` diff --git a/examples/images/diffusion/configs/Inference/v2-inference-v.yaml b/examples/images/diffusion/configs/Inference/v2-inference-v.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8ec8dfbfefe94ae8522c93017668fea78d580acf --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-inference-v.yaml @@ -0,0 +1,68 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Inference/v2-inference.yaml b/examples/images/diffusion/configs/Inference/v2-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..152c4f3c2b36c3b246a9cb10eb8166134b0d2e1c --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-inference.yaml @@ -0,0 +1,67 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Inference/v2-inpainting-inference.yaml b/examples/images/diffusion/configs/Inference/v2-inpainting-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..32a9471d71b828c51bcbbabfe34c5f6c8282c803 --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-inpainting-inference.yaml @@ -0,0 +1,158 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + +data: + target: ldm.data.laion.WebDataModuleFromConfig + params: + tar_base: null # for concat as in LAION-A + p_unsafe_threshold: 0.1 + filter_word_list: "data/filters.yaml" + max_pwatermark: 0.45 + batch_size: 8 + num_workers: 6 + multinode: True + min_size: 512 + train: + shards: + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" + shuffle: 10000 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + # NOTE use enough shards to avoid empty validation loops in workers + validation: + shards: + - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " + shuffle: 0 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.CenterCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + +lightning: + find_unused_parameters: True + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 10000 + + image_logger: + target: main.ImageLogger + params: + enable_autocast: False + disabled: False + batch_frequency: 1000 + max_images: 4 + increase_log_steps: False + log_first_step: False + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 4 + unconditional_guidance_scale: 5.0 + unconditional_guidance_label: [""] + ddim_steps: 50 # todo check these out for depth2img, + ddim_eta: 0.0 # todo check these out for depth2img, + + trainer: + benchmark: True + val_check_interval: 5000000 + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 diff --git a/examples/images/diffusion/configs/Inference/v2-midas-inference.yaml b/examples/images/diffusion/configs/Inference/v2-midas-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..531199de4878308c4f839726a767190c21de0a17 --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-midas-inference.yaml @@ -0,0 +1,72 @@ +model: + base_learning_rate: 5.0e-07 + target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + depth_stage_config: + target: ldm.modules.midas.api.MiDaSInference + params: + model_type: "dpt_hybrid" + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 5 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Inference/x4-upscaling.yaml b/examples/images/diffusion/configs/Inference/x4-upscaling.yaml new file mode 100644 index 0000000000000000000000000000000000000000..45ecbf9ad863b331f36fa28360afe8e9756883a6 --- /dev/null +++ b/examples/images/diffusion/configs/Inference/x4-upscaling.yaml @@ -0,0 +1,75 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion + params: + parameterization: "v" + low_scale_key: "lr" + linear_start: 0.0001 + linear_end: 0.02 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 128 + channels: 4 + cond_stage_trainable: false + conditioning_key: "hybrid-adm" + monitor: val/loss_simple_ema + scale_factor: 0.08333 + use_ema: False + + low_scale_config: + target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation + params: + noise_schedule_config: # image space + linear_start: 0.0001 + linear_end: 0.02 + max_noise_level: 350 + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + num_classes: 1000 # timesteps for noise conditioning (here constant, just need one) + image_size: 128 + in_channels: 7 + out_channels: 4 + model_channels: 256 + attention_resolutions: [ 2,4,8] + num_res_blocks: 2 + channel_mult: [ 1, 2, 2, 4] + disable_self_attentions: [True, True, True, False] + disable_middle_self_attn: False + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + use_linear_in_transformer: True + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + ddconfig: + # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though) + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Teyvat/README.md b/examples/images/diffusion/configs/Teyvat/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6a7ee88e5c043a391c43b8d1d73f8c1371c5ce4b --- /dev/null +++ b/examples/images/diffusion/configs/Teyvat/README.md @@ -0,0 +1,25 @@ +# Dataset Card for Teyvat BLIP captions +Dataset used to train [Teyvat characters text to image model](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion). + +BLIP generated captions for characters images from [genshin-impact fandom wiki](https://genshin-impact.fandom.com/wiki/Character#Playable_Characters)and [biligame wiki for genshin impact](https://wiki.biligame.com/ys/%E8%A7%92%E8%89%B2). + +For each row the dataset contains `image` and `text` keys. `image` is a varying size PIL png, and `text` is the accompanying text caption. Only a train split is provided. + +The `text` include the tag `Teyvat`, `Name`,`Element`, `Weapon`, `Region`, `Model type`, and `Description`, the `Description` is captioned with the [pre-trained BLIP model](https://github.com/salesforce/BLIP). +## Examples + + + +> Teyvat, Name:Ganyu, Element:Cryo, Weapon:Bow, Region:Liyue, Model type:Medium Female, Description:an anime character with blue hair and blue eyes + + + +> Teyvat, Name:Ganyu, Element:Cryo, Weapon:Bow, Region:Liyue, Model type:Medium Female, Description:an anime character with blue hair and blue eyes + + + +> Teyvat, Name:Keqing, Element:Electro, Weapon:Sword, Region:Liyue, Model type:Medium Female, Description:a anime girl with long white hair and blue eyes + + + +> Teyvat, Name:Keqing, Element:Electro, Weapon:Sword, Region:Liyue, Model type:Medium Female, Description:an anime character wearing a purple dress and cat ears diff --git a/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9048b3f80d1f21f3c07b2d177b9c510ed7f2cc8b --- /dev/null +++ b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml @@ -0,0 +1,126 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + +data: + target: main.DataModuleFromConfig + params: + batch_size: 16 + num_workers: 4 + train: + target: ldm.data.teyvat.hf_dataset + params: + path: Fazzie/Teyvat + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + - target: torchvision.transforms.RandomHorizontalFlip + +lightning: + trainer: + accelerator: 'gpu' + devices: 2 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: strategies.ColossalAIStrategy + params: + use_chunk: True + enable_distributed_storage: True + placement_policy: auto + force_outputs_fp32: true + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + target: loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname diff --git a/examples/images/diffusion/configs/train_colossalai.yaml b/examples/images/diffusion/configs/train_colossalai.yaml new file mode 100644 index 0000000000000000000000000000000000000000..155b26dd49cd9bd9dd782ade7a720f6f493a61d9 --- /dev/null +++ b/examples/images/diffusion/configs/train_colossalai.yaml @@ -0,0 +1,120 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + +data: + target: main.DataModuleFromConfig + params: + batch_size: 64 + wrap: False + train: + target: ldm.data.base.Txt2ImgIterableBaseDataset + params: + file_path: "/data/scratch/diffuser/laion_part0/" + world_size: 1 + rank: 0 + +lightning: + trainer: + accelerator: 'gpu' + devices: 1 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: strategies.ColossalAIStrategy + params: + use_chunk: True + enable_distributed_storage: True + placement_policy: auto + force_outputs_fp32: true + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + target: loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname diff --git a/examples/images/diffusion/configs/train_colossalai_cifar10.yaml b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5335bacbef6dfedadd1b8cc64eeca1e3f66e74b4 --- /dev/null +++ b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml @@ -0,0 +1,127 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + +data: + target: main.DataModuleFromConfig + params: + batch_size: 4 + num_workers: 4 + train: + target: ldm.data.cifar10.hf_dataset + params: + name: cifar10 + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + - target: torchvision.transforms.RandomHorizontalFlip + +lightning: + trainer: + accelerator: 'gpu' + devices: 1 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: strategies.ColossalAIStrategy + params: + use_chunk: True + enable_distributed_storage: True + placement_policy: auto + force_outputs_fp32: true + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + target: loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname diff --git a/examples/images/diffusion/configs/train_ddp.yaml b/examples/images/diffusion/configs/train_ddp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4308998f4bf56b5864b5e5bbf19a9450c6093ed4 --- /dev/null +++ b/examples/images/diffusion/configs/train_ddp.yaml @@ -0,0 +1,123 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + +data: + target: main.DataModuleFromConfig + params: + batch_size: 16 + num_workers: 4 + train: + target: ldm.data.teyvat.hf_dataset + params: + path: Fazzie/Teyvat + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + - target: torchvision.transforms.RandomHorizontalFlip + +lightning: + trainer: + accelerator: 'gpu' + devices: 2 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: strategies.DDPStrategy + params: + find_unused_parameters: False + log_every_n_steps: 2 +# max_steps: 6o + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + target: loggers.WandbLogger + params: + name: nowname + save_dir: "/data2/tmp/diff_log/" + offline: opt.debug + id: nowname diff --git a/examples/images/diffusion/configs/train_pokemon.yaml b/examples/images/diffusion/configs/train_pokemon.yaml new file mode 100644 index 0000000000000000000000000000000000000000..38e8485a3937cbb6b946549f47505c9d3cdbe2bf --- /dev/null +++ b/examples/images/diffusion/configs/train_pokemon.yaml @@ -0,0 +1,120 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + +data: + target: main.DataModuleFromConfig + params: + batch_size: 32 + wrap: False + train: + target: ldm.data.pokemon.PokemonDataset + # params: + # file_path: "/data/scratch/diffuser/laion_part0/" + # world_size: 1 + # rank: 0 + +lightning: + trainer: + accelerator: 'gpu' + devices: 1 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: strategies.ColossalAIStrategy + params: + use_chunk: True + enable_distributed_storage: True + placement_policy: auto + force_outputs_fp32: true + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + target: loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname diff --git a/examples/images/diffusion/environment.yaml b/examples/images/diffusion/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5b5579211d063ef716f3f73297369161b2ec0b34 --- /dev/null +++ b/examples/images/diffusion/environment.yaml @@ -0,0 +1,30 @@ +name: ldm +channels: + - pytorch + - defaults +dependencies: + - python=3.9.12 + - pip=20.3 + - cudatoolkit=11.3 + - pytorch=1.12.1 + - torchvision=0.13.1 + - numpy=1.23.1 + - pip: + - albumentations==1.3.0 + - opencv-python==4.6.0.66 + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - omegaconf==2.1.1 + - test-tube>=0.7.5 + - streamlit==1.12.1 + - einops==0.3.0 + - transformers==4.19.2 + - webdataset==0.2.5 + - kornia==0.6 + - open_clip_torch==2.0.2 + - invisible-watermark>=0.1.5 + - streamlit-drawable-canvas==0.8.0 + - torchmetrics==0.7.0 + - prefetch_generator + - datasets + - -e . diff --git a/examples/images/diffusion/ldm/data/__init__.py b/examples/images/diffusion/ldm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/diffusion/ldm/data/base.py b/examples/images/diffusion/ldm/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4f3cd35714a02d087e0a19ffd5f91ff514689ab9 --- /dev/null +++ b/examples/images/diffusion/ldm/data/base.py @@ -0,0 +1,75 @@ +import math +from abc import abstractmethod + +import torch +from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset +import os +import numpy as np +import cv2 + +class Txt2ImgIterableBaseDataset(IterableDataset): + ''' + Define an interface to make the IterableDatasets for text2img data chainable + ''' + def __init__(self, file_path: str, rank, world_size): + super().__init__() + self.file_path = file_path + self.folder_list = [] + self.file_list = [] + self.txt_list = [] + self.info = self._get_file_info(file_path) + self.start = self.info['start'] + self.end = self.info['end'] + self.rank = rank + + self.world_size = world_size + # self.per_worker = int(math.floor((self.end - self.start) / float(self.world_size))) + # self.iter_start = self.start + self.rank * self.per_worker + # self.iter_end = min(self.iter_start + self.per_worker, self.end) + # self.num_records = self.iter_end - self.iter_start + # self.valid_ids = [i for i in range(self.iter_end)] + self.num_records = self.end - self.start + self.valid_ids = [i for i in range(self.end)] + + print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + + def __len__(self): + # return self.iter_end - self.iter_start + return self.end - self.start + + def __iter__(self): + sample_iterator = self._sample_generator(self.start, self.end) + # sample_iterator = self._sample_generator(self.iter_start, self.iter_end) + return sample_iterator + + def _sample_generator(self, start, end): + for idx in range(start, end): + file_name = self.file_list[idx] + txt_name = self.txt_list[idx] + f_ = open(txt_name, 'r') + txt_ = f_.read() + f_.close() + image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = torch.from_numpy(image) / 255 + yield {"caption": txt_, "image":image} + + + def _get_file_info(self, file_path): + info = \ + { + "start": 1, + "end": 0, + } + self.folder_list = [file_path + i for i in os.listdir(file_path) if '.' not in i] + for folder in self.folder_list: + files = [folder + '/' + i for i in os.listdir(folder) if 'jpg' in i] + txts = [k.replace('jpg', 'txt') for k in files] + self.file_list.extend(files) + self.txt_list.extend(txts) + info['end'] = len(self.file_list) + # with open(file_path, 'r') as fin: + # for _ in enumerate(fin): + # info['end'] += 1 + # self.txt_list = [k.replace('jpg', 'txt') for k in self.file_list] + return info \ No newline at end of file diff --git a/examples/images/diffusion/ldm/data/cifar10.py b/examples/images/diffusion/ldm/data/cifar10.py new file mode 100644 index 0000000000000000000000000000000000000000..53cd61263b472d37c6b2b896cfd8ba2f89477b9a --- /dev/null +++ b/examples/images/diffusion/ldm/data/cifar10.py @@ -0,0 +1,184 @@ +from typing import Dict +import numpy as np +from omegaconf import DictConfig, ListConfig +import torch +from torch.utils.data import Dataset +from pathlib import Path +import json +from PIL import Image +from torchvision import transforms +from einops import rearrange +from ldm.util import instantiate_from_config +from datasets import load_dataset + +def make_multi_folder_data(paths, caption_files=None, **kwargs): + """Make a concat dataset from multiple folders + Don't suport captions yet + If paths is a list, that's ok, if it's a Dict interpret it as: + k=folder v=n_times to repeat that + """ + list_of_paths = [] + if isinstance(paths, (Dict, DictConfig)): + assert caption_files is None, \ + "Caption files not yet supported for repeats" + for folder_path, repeats in paths.items(): + list_of_paths.extend([folder_path]*repeats) + paths = list_of_paths + + if caption_files is not None: + datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)] + else: + datasets = [FolderData(p, **kwargs) for p in paths] + return torch.utils.data.ConcatDataset(datasets) + +class FolderData(Dataset): + def __init__(self, + root_dir, + caption_file=None, + image_transforms=[], + ext="jpg", + default_caption="", + postprocess=None, + return_paths=False, + ) -> None: + """Create a dataset from a folder of images. + If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = Path(root_dir) + self.default_caption = default_caption + self.return_paths = return_paths + if isinstance(postprocess, DictConfig): + postprocess = instantiate_from_config(postprocess) + self.postprocess = postprocess + if caption_file is not None: + with open(caption_file, "rt") as f: + ext = Path(caption_file).suffix.lower() + if ext == ".json": + captions = json.load(f) + elif ext == ".jsonl": + lines = f.readlines() + lines = [json.loads(x) for x in lines] + captions = {x["file_name"]: x["text"].strip("\n") for x in lines} + else: + raise ValueError(f"Unrecognised format: {ext}") + self.captions = captions + else: + self.captions = None + + if not isinstance(ext, (tuple, list, ListConfig)): + ext = [ext] + + # Only used if there is no caption file + self.paths = [] + for e in ext: + self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) + if isinstance(image_transforms, ListConfig): + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms = transforms.Compose(image_transforms) + self.tform = image_transforms + + + def __len__(self): + if self.captions is not None: + return len(self.captions.keys()) + else: + return len(self.paths) + + def __getitem__(self, index): + data = {} + if self.captions is not None: + chosen = list(self.captions.keys())[index] + caption = self.captions.get(chosen, None) + if caption is None: + caption = self.default_caption + filename = self.root_dir/chosen + else: + filename = self.paths[index] + + if self.return_paths: + data["path"] = str(filename) + + im = Image.open(filename) + im = self.process_im(im) + data["image"] = im + + if self.captions is not None: + data["txt"] = caption + else: + data["txt"] = self.default_caption + + if self.postprocess is not None: + data = self.postprocess(data) + + return data + + def process_im(self, im): + im = im.convert("RGB") + return self.tform(im) + +def hf_dataset( + name, + image_transforms=[], + image_column="img", + label_column="label", + text_column="txt", + split='train', + image_key='image', + caption_key='txt', + ): + """Make huggingface dataset with appropriate list of transforms applied + """ + ds = load_dataset(name, split=split) + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + tform = transforms.Compose(image_transforms) + + assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" + assert label_column in ds.column_names, f"Didn't find column {label_column} in {ds.column_names}" + + def pre_process(examples): + processed = {} + processed[image_key] = [tform(im) for im in examples[image_column]] + + label_to_text_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"} + + processed[caption_key] = [label_to_text_dict[label] for label in examples[label_column]] + + return processed + + ds.set_transform(pre_process) + return ds + +class TextOnly(Dataset): + def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1): + """Returns only captions with dummy images""" + self.output_size = output_size + self.image_key = image_key + self.caption_key = caption_key + if isinstance(captions, Path): + self.captions = self._load_caption_file(captions) + else: + self.captions = captions + + if n_gpus > 1: + # hack to make sure that all the captions appear on each gpu + repeated = [n_gpus*[x] for x in self.captions] + self.captions = [] + [self.captions.extend(x) for x in repeated] + + def __len__(self): + return len(self.captions) + + def __getitem__(self, index): + dummy_im = torch.zeros(3, self.output_size, self.output_size) + dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c') + return {self.image_key: dummy_im, self.caption_key: self.captions[index]} + + def _load_caption_file(self, filename): + with open(filename, 'rt') as f: + captions = f.readlines() + return [x.strip('\n') for x in captions] \ No newline at end of file diff --git a/examples/images/diffusion/ldm/data/imagenet.py b/examples/images/diffusion/ldm/data/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..1c473f9c6965b22315dbb289eff8247c71bdc790 --- /dev/null +++ b/examples/images/diffusion/ldm/data/imagenet.py @@ -0,0 +1,394 @@ +import os, yaml, pickle, shutil, tarfile, glob +import cv2 +import albumentations +import PIL +import numpy as np +import torchvision.transforms.functional as TF +from omegaconf import OmegaConf +from functools import partial +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset, Subset + +import taming.data.utils as tdu +from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve +from taming.data.imagenet import ImagePaths + +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light + + +def synset2idx(path_to_yaml="data/index_synset.yaml"): + with open(path_to_yaml) as f: + di2s = yaml.load(f) + return dict((v,k) for k,v in di2s.items()) + + +class ImageNetBase(Dataset): + def __init__(self, config=None): + self.config = config or OmegaConf.create() + if not type(self.config)==dict: + self.config = OmegaConf.to_container(self.config) + self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) + self.process_images = True # if False we skip loading & processing images and self.data contains filepaths + self._prepare() + self._prepare_synset_to_human() + self._prepare_idx_to_synset() + self._prepare_human_to_integer_label() + self._load() + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _prepare(self): + raise NotImplementedError() + + def _filter_relpaths(self, relpaths): + ignore = set([ + "n06596364_9591.JPEG", + ]) + relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _prepare_human_to_integer_label(self): + URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" + self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") + if (not os.path.exists(self.human2integer)): + download(URL, self.human2integer) + with open(self.human2integer, "r") as f: + lines = f.read().splitlines() + assert len(lines) == 1000 + self.human2integer_dict = dict() + for line in lines: + value, key = line.split(":") + self.human2integer_dict[key] = int(value) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + if not self.keep_orig_class_label: + self.class_labels = [class_dict[s] for s in self.synsets] + else: + self.class_labels = [self.synset2idx[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + + if self.process_images: + self.size = retrieve(self.config, "size", default=256) + self.data = ImagePaths(self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) + else: + self.data = self.abspaths + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.process_images = process_images + self.data_root = data_root + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", + default=True) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.data_root = data_root + self.process_images = process_images + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", + default=False) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + + +class ImageNetSR(Dataset): + def __init__(self, size=None, + degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., + random_crop=True): + """ + Imagenet Superresolution Dataloader + Performs following ops in order: + 1. crops a crop of size s from image either as random or center crop + 2. resizes crop to size with cv2.area_interpolation + 3. degrades resized crop with degradation_fn + + :param size: resizing to size after cropping + :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light + :param downscale_f: Low Resolution Downsample factor + :param min_crop_f: determines crop size s, + where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) + :param max_crop_f: "" + :param data_root: + :param random_crop: + """ + self.base = self.get_base() + assert size + assert (size / downscale_f).is_integer() + self.size = size + self.LR_size = int(size / downscale_f) + self.min_crop_f = min_crop_f + self.max_crop_f = max_crop_f + assert(max_crop_f <= 1.) + self.center_crop = not random_crop + + self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) + + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + + if degradation == "bsrgan": + self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) + + elif degradation == "bsrgan_light": + self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) + + else: + interpolation_fn = { + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, + }[degradation] + + self.pil_interpolation = degradation.startswith("pil_") + + if self.pil_interpolation: + self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) + + else: + self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, + interpolation=interpolation_fn) + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = Image.open(example["file_path_"]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + image = np.array(image).astype(np.uint8) + + min_side_len = min(image.shape[:2]) + crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) + crop_side_len = int(crop_side_len) + + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) + + else: + self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) + + image = self.cropper(image=image)["image"] + image = self.image_rescaler(image=image)["image"] + + if self.pil_interpolation: + image_pil = PIL.Image.fromarray(image) + LR_image = self.degradation_process(image_pil) + LR_image = np.array(LR_image).astype(np.uint8) + + else: + LR_image = self.degradation_process(image=image)["image"] + + example["image"] = (image/127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + + return example + + +class ImageNetSRTrain(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_train_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetTrain(process_images=False,) + return Subset(dset, indices) + + +class ImageNetSRValidation(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_val_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetValidation(process_images=False,) + return Subset(dset, indices) diff --git a/examples/images/diffusion/ldm/data/lsun.py b/examples/images/diffusion/ldm/data/lsun.py new file mode 100644 index 0000000000000000000000000000000000000000..6256e45715ff0b57c53f985594d27cbbbff0e68e --- /dev/null +++ b/examples/images/diffusion/ldm/data/lsun.py @@ -0,0 +1,92 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +class LSUNBase(Dataset): + def __init__(self, + txt_file, + data_root, + size=None, + interpolation="bicubic", + flip_p=0.5 + ): + self.data_paths = txt_file + self.data_root = data_root + with open(self.data_paths, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) + for l in self.image_paths], + } + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example + + +class LSUNChurchesTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) + + +class LSUNChurchesValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", + flip_p=flip_p, **kwargs) + + +class LSUNBedroomsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) + + +class LSUNBedroomsValidation(LSUNBase): + def __init__(self, flip_p=0.0, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", + flip_p=flip_p, **kwargs) + + +class LSUNCatsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) + + +class LSUNCatsValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", + flip_p=flip_p, **kwargs) diff --git a/examples/images/diffusion/ldm/data/teyvat.py b/examples/images/diffusion/ldm/data/teyvat.py new file mode 100644 index 0000000000000000000000000000000000000000..61dc29d56e7c0c1e9eb3df6abc8dca49438be1b6 --- /dev/null +++ b/examples/images/diffusion/ldm/data/teyvat.py @@ -0,0 +1,152 @@ +from typing import Dict +import numpy as np +from omegaconf import DictConfig, ListConfig +import torch +from torch.utils.data import Dataset +from pathlib import Path +import json +from PIL import Image +from torchvision import transforms +from einops import rearrange +from ldm.util import instantiate_from_config +from datasets import load_dataset + +def make_multi_folder_data(paths, caption_files=None, **kwargs): + """Make a concat dataset from multiple folders + Don't suport captions yet + If paths is a list, that's ok, if it's a Dict interpret it as: + k=folder v=n_times to repeat that + """ + list_of_paths = [] + if isinstance(paths, (Dict, DictConfig)): + assert caption_files is None, \ + "Caption files not yet supported for repeats" + for folder_path, repeats in paths.items(): + list_of_paths.extend([folder_path]*repeats) + paths = list_of_paths + + if caption_files is not None: + datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)] + else: + datasets = [FolderData(p, **kwargs) for p in paths] + return torch.utils.data.ConcatDataset(datasets) + +class FolderData(Dataset): + def __init__(self, + root_dir, + caption_file=None, + image_transforms=[], + ext="jpg", + default_caption="", + postprocess=None, + return_paths=False, + ) -> None: + """Create a dataset from a folder of images. + If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = Path(root_dir) + self.default_caption = default_caption + self.return_paths = return_paths + if isinstance(postprocess, DictConfig): + postprocess = instantiate_from_config(postprocess) + self.postprocess = postprocess + if caption_file is not None: + with open(caption_file, "rt") as f: + ext = Path(caption_file).suffix.lower() + if ext == ".json": + captions = json.load(f) + elif ext == ".jsonl": + lines = f.readlines() + lines = [json.loads(x) for x in lines] + captions = {x["file_name"]: x["text"].strip("\n") for x in lines} + else: + raise ValueError(f"Unrecognised format: {ext}") + self.captions = captions + else: + self.captions = None + + if not isinstance(ext, (tuple, list, ListConfig)): + ext = [ext] + + # Only used if there is no caption file + self.paths = [] + for e in ext: + self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) + if isinstance(image_transforms, ListConfig): + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms = transforms.Compose(image_transforms) + self.tform = image_transforms + + + def __len__(self): + if self.captions is not None: + return len(self.captions.keys()) + else: + return len(self.paths) + + def __getitem__(self, index): + data = {} + if self.captions is not None: + chosen = list(self.captions.keys())[index] + caption = self.captions.get(chosen, None) + if caption is None: + caption = self.default_caption + filename = self.root_dir/chosen + else: + filename = self.paths[index] + + if self.return_paths: + data["path"] = str(filename) + + im = Image.open(filename) + im = self.process_im(im) + data["image"] = im + + if self.captions is not None: + data["txt"] = caption + else: + data["txt"] = self.default_caption + + if self.postprocess is not None: + data = self.postprocess(data) + + return data + + def process_im(self, im): + im = im.convert("RGB") + return self.tform(im) + +def hf_dataset( + path = "Fazzie/Teyvat", + image_transforms=[], + image_column="image", + text_column="text", + image_key='image', + caption_key='txt', + ): + """Make huggingface dataset with appropriate list of transforms applied + """ + ds = load_dataset(path, name="train") + ds = ds["train"] + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))] + ) + tform = transforms.Compose(image_transforms) + + assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" + assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}" + + def pre_process(examples): + processed = {} + processed[image_key] = [tform(im) for im in examples[image_column]] + processed[caption_key] = examples[text_column] + + return processed + + ds.set_transform(pre_process) + return ds \ No newline at end of file diff --git a/examples/images/diffusion/ldm/lr_scheduler.py b/examples/images/diffusion/ldm/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..be39da9ca6dacc22bf3df9c7389bbb403a4a3ade --- /dev/null +++ b/examples/images/diffusion/ldm/lr_scheduler.py @@ -0,0 +1,98 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/examples/images/diffusion/ldm/models/autoencoder.py b/examples/images/diffusion/ldm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b1bd8377835bfa1df71178d21dcb245241444a17 --- /dev/null +++ b/examples/images/diffusion/ldm/models/autoencoder.py @@ -0,0 +1,223 @@ +import torch +try: + import lightning.pytorch as pl +except: + import pytorch_lightning as pl + +import torch.nn.functional as F +from contextlib import contextmanager + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config +from ldm.modules.ema import LitEma + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False + ): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + if self.use_ema: + self.ema_decay = ema_decay + assert 0. < ema_decay < 1. + self.model_ema = LitEma(self, decay=ema_decay) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, postfix=""): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( + self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) + if self.learn_logvar: + print(f"{self.__class__.__name__}: Learning logvar") + ae_params_list.append(self.loss.logvar) + opt_ae = torch.optim.Adam(ae_params_list, + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) + log["reconstructions_ema"] = xrec_ema + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x + diff --git a/examples/images/diffusion/ldm/models/diffusion/__init__.py b/examples/images/diffusion/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/diffusion/ldm/models/diffusion/classifier.py b/examples/images/diffusion/ldm/models/diffusion/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..612a8371bf201c33988b4df895883b03233f3859 --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import lightning.pytorch as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/examples/images/diffusion/ldm/models/diffusion/ddim.py b/examples/images/diffusion/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..27ead0ea914c64c747b64e690662899fb3801144 --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/ddim.py @@ -0,0 +1,336 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, + ucg_schedule=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + if ucg_schedule is not None: + assert len(ucg_schedule) == len(time_range) + unconditional_guidance_scale = ucg_schedule[i] + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + model_output = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [torch.cat([ + unconditional_conditioning[k][i], + c[k][i]]) for i in range(len(c[k]))] + else: + c_in[k] = torch.cat([ + unconditional_conditioning[k], + c[k]]) + elif isinstance(c, list): + c_in = list() + assert isinstance(unconditional_conditioning, list) + for i in range(len(c)): + c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) + else: + c_in = torch.cat([unconditional_conditioning, c]) + model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", 'not implemented' + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + raise NotImplementedError() + + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, + unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): + num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc='Encoding Image'): + t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) + if unconditional_guidance_scale == 1.: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), + torch.cat((unconditional_conditioning, c))), 2) + noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = alphas_next[i].sqrt() * ( + (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + x_next = xt_weighted + weighted_noise_pred + if return_intermediates and i % ( + num_steps // return_intermediates) == 0 and i < num_steps - 1: + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + if callback: callback(i) + + out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + if return_intermediates: + out.update({'intermediates': intermediates}) + return x_next, out + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False, callback=None): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + if callback: callback(i) + return x_dec \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/ddpm.py b/examples/images/diffusion/ldm/models/diffusion/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..f7ac0a735f10ab5a4132fe136215a0506c266a3c --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1895 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import torch +import torch.nn as nn +import numpy as np +try: + import lightning.pytorch as pl + from lightning.pytorch.utilities import rank_zero_only, rank_zero_info +except: + import pytorch_lightning as pl + from pytorch_lightning.utilities import rank_zero_only, rank_zero_info +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager, nullcontext +from functools import partial +import itertools +from tqdm import tqdm +from torchvision.utils import make_grid + +from omegaconf import ListConfig + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL + + +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.modules.diffusionmodules.openaimodel import * + +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d +from ldm.modules.encoders.modules import * + +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import * +from ldm.models.diffusion.ddim import * +from ldm.modules.diffusionmodules.openaimodel import * +from ldm.modules.diffusionmodules.model import * + + +from ldm.modules.diffusionmodules.model import Model, Encoder, Decoder + +from ldm.util import instantiate_from_config + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + use_fp16 = True, + make_it_fit=False, + ucg_training=None, + reset_ema=False, + reset_num_ema_updates=False, + ): + super().__init__() + assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"' + self.parameterization = parameterization + rank_zero_info(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + + self.unet_config = unet_config + self.conditioning_key = conditioning_key + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + self.make_it_fit = make_it_fit + self.ckpt_path = ckpt_path + self.ignore_keys = ignore_keys + self.load_only_unet = load_only_unet + self.reset_ema = reset_ema + self.reset_num_ema_updates = reset_num_ema_updates + + if reset_ema: assert exists(ckpt_path) + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + if reset_ema: + assert self.use_ema + print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + self.timesteps = timesteps + self.beta_schedule = beta_schedule + self.given_betas = given_betas + self.linear_start = linear_start + self.linear_end = linear_end + self.cosine_s = cosine_s + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.logvar_init = logvar_init + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + self.use_fp16 = use_fp16 + self.ucg_training = ucg_training or dict() + if self.ucg_training: + self.ucg_prng = np.random.RandomState() + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + elif self.parameterization == "v": + lvlb_weights = torch.ones_like(self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))) + else: + raise NotImplementedError("mu not supported") + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + @torch.no_grad() + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + if self.make_it_fit: + n_params = len([name for name, _ in + itertools.chain(self.named_parameters(), + self.named_buffers())]) + for name, param in tqdm( + itertools.chain(self.named_parameters(), + self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params + ): + if not name in sd: + continue + old_shape = sd[name].shape + new_shape = param.shape + assert len(old_shape) == len(new_shape) + if len(new_shape) > 2: + # we only modify first two axes + assert new_shape[2:] == old_shape[2:] + # assumes first axis corresponds to output dim + if not new_shape == old_shape: + new_param = param.clone() + old_param = sd[name] + if len(new_shape) == 1: + for i in range(new_param.shape[0]): + new_param[i] = old_param[i % old_shape[0]] + elif len(new_shape) >= 2: + for i in range(new_param.shape[0]): + for j in range(new_param.shape[1]): + new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]] + + n_used_old = torch.ones(old_shape[1]) + for j in range(new_param.shape[1]): + n_used_old[j % old_shape[1]] += 1 + n_used_new = torch.zeros(new_shape[1]) + for j in range(new_param.shape[1]): + n_used_new[j] = n_used_old[j % old_shape[1]] + + n_used_new = n_used_new[None, :] + while len(n_used_new.shape) < len(new_shape): + n_used_new = n_used_new.unsqueeze(-1) + new_param /= n_used_new + + sd[name] = new_param + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys:\n {missing}") + if len(unexpected) > 0: + print(f"\nUnexpected Keys:\n {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def predict_start_from_z_and_v(self, x_t, t, v): + # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + if self.use_fp16: + x = x.to(memory_format=torch.contiguous_format).half() + else: + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + for k in self.ucg_training: + p = self.ucg_training[k]["p"] + val = self.ucg_training[k]["val"] + if val is None: + val = "" + for i in range(len(batch[k])): + if self.ucg_prng.choice(2, p=[1 - p, p]): + batch[k][i] = val + + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + use_fp16=True, + force_null_conditioning=False, + *args, **kwargs): + self.force_null_conditioning = force_null_conditioning + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning: + conditioning_key = None + + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.first_stage_config = first_stage_config + self.cond_stage_config = cond_stage_config + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if self.ckpt_path is not None: + self.init_from_ckpt(self.ckpt_path, self.ignore_keys) + self.restarted_from_ckpt = True + if self.reset_ema: + assert self.use_ema + print( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if self.reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + def configure_sharded_model(self) -> None: + rank_zero_info("Configure sharded model for LatentDiffusion") + self.model = DiffusionWrapper(self.unet_config, self.conditioning_key) + if self.use_ema: + self.model_ema = LitEma(self.model) + + if self.ckpt_path is not None: + self.init_from_ckpt(self.ckpt_path, ignore_keys=self.ignore_keys, only_model=self.load_only_unet) + if self.reset_ema: + assert self.use_ema + print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if self.reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + self.register_schedule(given_betas=self.given_betas, beta_schedule=self.beta_schedule, timesteps=self.timesteps, + linear_start=self.linear_start, linear_end=self.linear_end, cosine_s=self.cosine_s) + + self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + if self.ucg_training: + self.ucg_prng = np.random.RandomState() + + self.instantiate_first_stage(self.first_stage_config) + self.instantiate_cond_stage(self.cond_stage_config) + if self.ckpt_path is not None: + self.init_from_ckpt(self.ckpt_path, self.ignore_keys) + self.restarted_from_ckpt = True + if self.reset_ema: + assert self.use_ema + print( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if self.reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z.half() if self.use_fp16 else self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None, return_x=False): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None and not self.force_null_conditioning: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox', "txt"]: + xc = batch[cond_key] + elif cond_key in ['class_label', 'cls']: + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_x: + out.extend([x]) + if return_original_cond: + out.append(xc) + + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def apply_model(self, x_noisy, t, cond, return_ids=False): + if isinstance(cond, dict): + # hybrid case, cond is expected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None, **kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, + shape, cond, verbose=False, **kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True, **kwargs) + + return samples, intermediates + + @torch.no_grad() + def get_unconditional_conditioning(self, batch_size, null_label=None): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + if self.cond_stage_key in ["class_label", "cls"]: + xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device) + return self.get_learned_conditioning(xc) + else: + raise NotImplementedError("todo") + if isinstance(c, list): # in case the encoder gives us a list + for i in range(len(c)): + c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device) + else: + c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device) + return c + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', "cls"]: + try: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + except KeyError: + # probably no "human_label" in batch + pass + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc = self.get_unconditional_conditioning(N, unconditional_guidance_label) + if self.model.conditioning_key == "crossattn-adm": + uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with ema_scope("Plotting Inpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + mask = 1. - mask + with ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + + from colossalai.nn.optimizer import HybridAdam + opt = HybridAdam(params, lr=lr) + + # opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + if not self.sequential_cross_attn: + cc = torch.cat(c_crossattn, 1) + else: + cc = c_crossattn + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'hybrid-adm': + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, y=c_adm) + elif self.conditioning_key == 'crossattn-adm': + assert c_adm is not None + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc, y=c_adm) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class LatentUpscaleDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + self.noise_level_key = noise_level_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + x_low = batch[self.low_scale_key][:bs] + x_low = rearrange(x_low, 'b h w c -> b c h w') + if self.use_fp16: + x_low = x_low.to(memory_format=torch.contiguous_format).half() + else: + x_low = x_low.to(memory_format=torch.contiguous_format).float() + zx, noise_level = self.low_scale_model(x_low) + if self.noise_level_key is not None: + # get noise level from batch instead, e.g. when extracting a custom noise level for bsr + raise NotImplementedError('TODO') + + all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} + if log_mode: + # TODO: maybe disable if too expensive + x_low_rec = self.low_scale_model.decode(zx) + return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level + return z, all_conds + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, + unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N, + log_mode=True) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', 'cls']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) + # TODO explore better "unconditional" choices for the other keys + # maybe guide away from empty text label and highest noise level and maximally degraded zx? + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif k == "c_adm": # todo: only run with text-based guidance? + assert isinstance(c[k], torch.Tensor) + #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + uc[k] = c[k] + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + return log + + +class LatentFinetuneDiffusion(LatentDiffusion): + """ + Basis for different finetunas, such as inpainting or depth2image + To disable finetuning mode, set finetune_keys to None + """ + + def __init__(self, + concat_keys: tuple, + finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight" + ), + keep_finetune_dims=4, + # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, **kwargs + ): + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", list()) + super().__init__(*args, **kwargs) + self.finetune_keys = finetune_keys + self.concat_keys = concat_keys + self.keep_dims = keep_finetune_dims + self.c_concat_log_start = c_concat_log_start + self.c_concat_log_end = c_concat_log_end + if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint' + if exists(ckpt_path): + self.init_from_ckpt(ckpt_path, ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + + # make it explicit, finetune by including extra input channels + if exists(self.finetune_keys) and k in self.finetune_keys: + new_entry = None + for name, param in self.named_parameters(): + if name in self.finetune_keys: + print( + f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only") + new_entry = torch.zeros_like(param) # zero init + assert exists(new_entry), 'did not find matching parameter to modify' + new_entry[:, :self.keep_dims, ...] = sd[k] + sd[k] = new_entry + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True) + c_cat, c = c["c_concat"][0], c["c_crossattn"][0] + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', 'cls']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if not (self.c_concat_log_start is None and self.c_concat_log_end is None): + log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end]) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label) + uc_cat = c_cat + uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + return log + + +class LatentInpaintDiffusion(LatentFinetuneDiffusion): + """ + can either run as pure inpainting model (only concat mode) or with mixed conditionings, + e.g. mask as concat and text via cross-attn. + To disable finetuning mode, set finetune_keys to None + """ + + def __init__(self, + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + *args, **kwargs + ): + super().__init__(concat_keys, *args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + c_cat = list() + for ck in self.concat_keys: + if self.use_fp16: + cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).half() + else: + cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + bchw = z.shape + if ck != self.masked_image_key: + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs) + log["masked_image"] = rearrange(args[0]["masked_image"], + 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + return log + + +class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): + """ + condition on monocular depth estimation + """ + + def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.depth_model = instantiate_from_config(depth_stage_config) + self.depth_stage_key = concat_keys[0] + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + c_cat = list() + for ck in self.concat_keys: + cc = batch[ck] + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + cc = self.depth_model(cc) + cc = torch.nn.functional.interpolate( + cc, + size=z.shape[2:], + mode="bicubic", + align_corners=False, + ) + + depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3], + keepdim=True) + cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1. + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + depth = self.depth_model(args[0][self.depth_stage_key]) + depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \ + torch.amax(depth, dim=[1, 2, 3], keepdim=True) + log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1. + return log + + +class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): + """ + condition on low-res image (and optionally on some spatial noise augmentation) + """ + def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None, + low_scale_config=None, low_scale_key=None, *args, **kwargs): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.reshuffle_patch_size = reshuffle_patch_size + self.low_scale_model = None + if low_scale_config is not None: + print("Initializing a low-scale model") + assert exists(low_scale_key) + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + # optionally make spatial noise_level here + c_cat = list() + noise_level = None + for ck in self.concat_keys: + cc = batch[ck] + cc = rearrange(cc, 'b h w c -> b c h w') + if exists(self.reshuffle_patch_size): + assert isinstance(self.reshuffle_patch_size, int) + cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', + p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size) + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + if exists(self.low_scale_model) and ck == self.low_scale_key: + cc, noise_level = self.low_scale_model(cc) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + if exists(noise_level): + all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level} + else: + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') + return log diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7427f38c07530afbab79154ea8aaf88c4bf70a08 --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py @@ -0,0 +1 @@ +from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..095e5ba3ce0b1aa7f4b3f1e2e5d8fff7cfe6dc8c --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -0,0 +1,1154 @@ +import torch +import torch.nn.functional as F +import math +from tqdm import tqdm + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + t = self.inverse_lambda(lambda_t) + =============================================================== + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + 1. For discrete-time DPMs: + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + 2. For continuous-time DPMs: + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + =============================================================== + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + Example: + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError( + "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( + schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), + self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0 ** 2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + We support four types of the diffusion model by setting `model_type`: + 1. "noise": noise prediction model. (Trained by predicting noise). + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + =============================================================== + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + """Construct a DPM-Solver. + We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). + If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). + If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). + In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. + The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. + thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. + max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = model_fn + self.noise_schedule = noise_schedule + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3, ] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3, ] * (K - 1) + [1] + else: + orders = [3, ] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2, ] * K + else: + K = steps // 2 + 1 + orders = [2, ] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1, ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum(torch.tensor([0, ] + orders)).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.predict_x0: + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, + solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( + s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( + model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, + return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( + s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( + s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + dims = x.dim() + model_prev_1, model_prev_0 = model_prev_list + t_prev_1, t_prev_0 = t_prev_list + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( + t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + if self.predict_x0: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + ) + else: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( + t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 + ) + else: + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, + r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, + solver_type='dpm_solver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + [1] A. Jolicoeur-Martineau, K. Li, R. Pichรฉ-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((x.shape[0],)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + solver_type=solver_type, + **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + return_intermediate=True, + solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, + solver_type=solver_type, + **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', + method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', + atol=0.0078, rtol=0.05, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + ===================================================== + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + ===================================================== + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'adaptive': + with torch.no_grad(): + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, + solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in tqdm(range(1, order), desc="DPM init order"): + vec_t = timesteps[init_order].expand(x.shape[0]) + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, + solver_type=solver_type) + model_prev_list.append(self.model_fn(x, vec_t)) + t_prev_list.append(vec_t) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in tqdm(range(order, steps + 1), desc="DPM multistep"): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final and steps < 15: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, + solver_type=solver_type) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, vec_t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, + skip_type=skip_type, + t_T=t_T, t_0=t_0, + device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order, ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for i, order in enumerate(orders): + t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), + N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..7d137b8cf36718c1c58faa09f9dd919e5fb2977b --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py @@ -0,0 +1,87 @@ +"""SAMPLING ONLY.""" +import torch + +from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver + + +MODEL_TYPES = { + "eps": "noise", + "v": "v" +} + + +class DPMSolverSampler(object): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + + device = self.model.betas.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type=MODEL_TYPES[self.model.parameterization], + guidance_type="classifier-free", + condition=conditioning, + unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + + dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + + return x.to(device), None \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/plms.py b/examples/images/diffusion/ldm/models/diffusion/plms.py new file mode 100644 index 0000000000000000000000000000000000000000..7002a365d27168ced0a04e9a4d83e088f8284eae --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/plms.py @@ -0,0 +1,244 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from ldm.models.diffusion.sampling_util import norm_thresholding + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next, + dynamic_threshold=dynamic_threshold) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/examples/images/diffusion/ldm/models/diffusion/sampling_util.py b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py new file mode 100644 index 0000000000000000000000000000000000000000..7eff02be6d7c54d43ee6680636ac0698dd3b3f33 --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py @@ -0,0 +1,22 @@ +import torch +import numpy as np + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions. + From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + +def norm_thresholding(x0, value): + s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) + return x0 * (value / s) + + +def spatial_norm_thresholding(x0, value): + # b c h w + s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) + return x0 * (value / s) \ No newline at end of file diff --git a/examples/images/diffusion/ldm/modules/attention.py b/examples/images/diffusion/ldm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..d504d939f6a02cf45f028799d7d73b84500cee06 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/attention.py @@ -0,0 +1,331 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +from typing import Optional, Any + +from ldm.modules.diffusionmodules.util import checkpoint + + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads.") + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention + } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + assert attn_mode in self.ATTENTION_MODES + attn_cls = self.ATTENTION_MODES[attn_mode] + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/__init__.py b/examples/images/diffusion/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/model.py b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..57b9a4b80f4bd8fa7c360c73ce55fe13349a8113 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,857 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange +from typing import Optional, Any + +try: + from lightning.pytorch.utilities import rank_zero_info +except: + from pytorch_lightning.utilities import rank_zero_info + +from ldm.modules.attention import MemoryEfficientCrossAttention + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + print("No module 'xformers'. Proceeding without it.") + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.attention_op: Optional[Any] = None + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return x+out + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None): + b, c, h, w = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + out = super().forward(x, context=context, mask=mask) + out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) + return x + out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' + if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": + attn_type = "vanilla-xformers" + rank_zero_info(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + rank_zero_info(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return MemoryEfficientAttnBlock(in_channels) + elif type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + raise NotImplementedError() + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + rank_zero_info("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + rank_zero_info(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + rank_zero_info(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..cd639d9360466c72b92db403e52514a149997ed8 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,787 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer +from ldm.util import exists + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + t_emb = t_emb.type(self.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py new file mode 100644 index 0000000000000000000000000000000000000000..03816662098ce1ffac79bd939b892e867ab91988 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import numpy as np +from functools import partial + +from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule +from ldm.util import default + + +class AbstractLowScaleModel(nn.Module): + # for concatenating a downsampled image to the latent representation + def __init__(self, noise_schedule_config=None): + super(AbstractLowScaleModel, self).__init__() + if noise_schedule_config is not None: + self.register_schedule(**noise_schedule_config) + + def register_schedule(self, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def forward(self, x): + return x, None + + def decode(self, x): + return x + + +class SimpleImageConcat(AbstractLowScaleModel): + # no noise level conditioning + def __init__(self): + super(SimpleImageConcat, self).__init__(noise_schedule_config=None) + self.max_noise_level = 0 + + def forward(self, x): + # fix to constant noise level + return x, torch.zeros(x.shape[0], device=x.device).long() + + +class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): + def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): + super().__init__(noise_schedule_config=noise_schedule_config) + self.max_noise_level = max_noise_level + + def forward(self, x, noise_level=None): + if noise_level is None: + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + else: + assert isinstance(noise_level, torch.Tensor) + z = self.q_sample(x, noise_level) + return z, noise_level + + + diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/util.py b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..e0621032dfc14353abfbc96b5030128ad650160f --- /dev/null +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,271 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled()} + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return nn.GroupNorm(16, channels) + # return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/examples/images/diffusion/ldm/modules/distributions/__init__.py b/examples/images/diffusion/ldm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/diffusion/ldm/modules/distributions/distributions.py b/examples/images/diffusion/ldm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/examples/images/diffusion/ldm/modules/ema.py b/examples/images/diffusion/ldm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..bded25019b9bcbcd0260f0b8185f8c7859ca58c4 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/ema.py @@ -0,0 +1,80 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates + else torch.tensor(-1, dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace('.', '') + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/examples/images/diffusion/ldm/modules/encoders/__init__.py b/examples/images/diffusion/ldm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/diffusion/ldm/modules/encoders/modules.py b/examples/images/diffusion/ldm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..4edd5496b9e668ea72a5be39db9cca94b6a42f9b --- /dev/null +++ b/examples/images/diffusion/ldm/modules/encoders/modules.py @@ -0,0 +1,213 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel + +import open_clip +from ldm.util import default, count_params + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + + def encode(self, x): + return x + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate + + def forward(self, batch, key=None, disable_dropout=False): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + if self.ucg_rate > 0. and not disable_dropout: + mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) + c = c.long() + c = self.embedding(c) + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc} + return uc + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = [ + "last", + "pooled", + "hidden" + ] + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, + freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="last"): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenCLIPT5Encoder(AbstractEncoder): + def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", + clip_max_length=77, t5_max_length=77): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] + + diff --git a/examples/images/diffusion/ldm/modules/image_degradation/__init__.py b/examples/images/diffusion/ldm/modules/image_degradation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc --- /dev/null +++ b/examples/images/diffusion/ldm/modules/image_degradation/__init__.py @@ -0,0 +1,2 @@ +from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr +from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py new file mode 100644 index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b --- /dev/null +++ b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py @@ -0,0 +1,730 @@ +# -*- coding: utf-8 -*- +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + elif i == 1: + image = add_blur(image, sf=sf) + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image":image} + return example + + +# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + + return img, hq + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') + + diff --git a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py new file mode 100644 index 0000000000000000000000000000000000000000..808c7f882cb75e2ba2340d5b55881d11927351f0 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py @@ -0,0 +1,651 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2/4 + wd = wd/4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + # elif i == 1: + # image = add_blur(image, sf=sf) + + if i == 0: + pass + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + # + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + if up: + image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then + example = {"image": image} + return example + + + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') diff --git a/examples/images/diffusion/ldm/modules/image_degradation/utils/test.png b/examples/images/diffusion/ldm/modules/image_degradation/utils/test.png new file mode 100644 index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6 Binary files /dev/null and b/examples/images/diffusion/ldm/modules/image_degradation/utils/test.png differ diff --git a/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py b/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py new file mode 100644 index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py @@ -0,0 +1,916 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + + +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2๏ผŒcmap=cmap) + plt.show() + + +''' +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) +# print(w1) +# print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(unit) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. +# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +''' +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +''' + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border:h-border, border:w-border] + return img + + +''' +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# ycbcr2rgb(img): +# -------------------------------------------- +''' + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + print('---') +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/examples/images/diffusion/ldm/modules/midas/__init__.py b/examples/images/diffusion/ldm/modules/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/diffusion/ldm/modules/midas/api.py b/examples/images/diffusion/ldm/modules/midas/api.py new file mode 100644 index 0000000000000000000000000000000000000000..b58ebbffd942a2fc22264f0ab47e400c26b9f41c --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/api.py @@ -0,0 +1,170 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from ldm.modules.midas.midas.dpt_depth import DPTDepthModel +from ldm.modules.midas.midas.midas_net import MidasNet +from ldm.modules.midas.midas.midas_net_custom import MidasNet_small +from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet + + +ISL_PATHS = { + "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", + "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", + "midas_v21": "", + "midas_v21_small": "", +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array + # NOTE: we expect that the correct transform has been called during dataloading. + with torch.no_grad(): + prediction = self.model(x) + prediction = torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=x.shape[2:], + mode="bicubic", + align_corners=False, + ) + assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) + return prediction + diff --git a/examples/images/diffusion/ldm/modules/midas/midas/__init__.py b/examples/images/diffusion/ldm/modules/midas/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/diffusion/ldm/modules/midas/midas/base_model.py b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/examples/images/diffusion/ldm/modules/midas/midas/blocks.py b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/examples/images/diffusion/ldm/modules/midas/midas/transforms.py b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/examples/images/diffusion/ldm/modules/midas/midas/vit.py b/examples/images/diffusion/ldm/modules/midas/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/examples/images/diffusion/ldm/modules/midas/utils.py b/examples/images/diffusion/ldm/modules/midas/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/examples/images/diffusion/ldm/util.py b/examples/images/diffusion/ldm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..8c09ca1c72f7ceb3f9d7f9546aae5561baf62b13 --- /dev/null +++ b/examples/images/diffusion/ldm/util.py @@ -0,0 +1,197 @@ +import importlib + +import torch +from torch import optim +import numpy as np + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x,torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +class AdamWwithEMAandWings(optim.Optimizer): + # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 + def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using + weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code + ema_power=1., param_names=()): + """AdamW that saves EMA versions of the parameters.""" + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= ema_decay <= 1.0: + raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, + ema_power=ema_power, param_names=param_names) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + ema_params_with_grad = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + ema_decay = group['ema_decay'] + ema_power = group['ema_power'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of parameter values + state['param_exp_avg'] = p.detach().float().clone() + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + ema_params_with_grad.append(state['param_exp_avg']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + optim._functional.adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=False) + + cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + for param, ema_param in zip(params_with_grad, ema_params_with_grad): + ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) + + return loss \ No newline at end of file diff --git a/examples/images/diffusion/main.py b/examples/images/diffusion/main.py new file mode 100644 index 0000000000000000000000000000000000000000..87d4951237145938d6e42ad02b1fe2079b02b2a7 --- /dev/null +++ b/examples/images/diffusion/main.py @@ -0,0 +1,826 @@ +import argparse +import csv +import datetime +import glob +import importlib +import os +import sys +import time + +import numpy as np +import torch +import torchvision + +try: + import lightning.pytorch as pl +except: + import pytorch_lightning as pl + +from functools import partial + +from omegaconf import OmegaConf +from packaging import version +from PIL import Image +from prefetch_generator import BackgroundGenerator +from torch.utils.data import DataLoader, Dataset, Subset, random_split + +try: + from lightning.pytorch import seed_everything + from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint + from lightning.pytorch.trainer import Trainer + from lightning.pytorch.utilities import rank_zero_info, rank_zero_only + LIGHTNING_PACK_NAME = "lightning.pytorch." +except: + from pytorch_lightning import seed_everything + from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint + from pytorch_lightning.trainer import Trainer + from pytorch_lightning.utilities import rank_zero_info, rank_zero_only + LIGHTNING_PACK_NAME = "pytorch_lightning." + +from ldm.data.base import Txt2ImgIterableBaseDataset +from ldm.util import instantiate_from_config + +# from ldm.modules.attention import enable_flash_attentions + + +class DataLoaderX(DataLoader): + + def __iter__(self): + return BackgroundGenerator(super().__iter__()) + + +def get_parser(**parser_kwargs): + + def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + parser = argparse.ArgumentParser(**parser_kwargs) + parser.add_argument( + "-n", + "--name", + type=str, + const=True, + default="", + nargs="?", + help="postfix for logdir", + ) + parser.add_argument( + "-r", + "--resume", + type=str, + const=True, + default="", + nargs="?", + help="resume from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-b", + "--base", + nargs="*", + metavar="base_config.yaml", + help="paths to base configs. Loaded from left-to-right. " + "Parameters can be overwritten or added with command-line options of the form `--key value`.", + default=list(), + ) + parser.add_argument( + "-t", + "--train", + type=str2bool, + const=True, + default=False, + nargs="?", + help="train", + ) + parser.add_argument( + "--no-test", + type=str2bool, + const=True, + default=False, + nargs="?", + help="disable test", + ) + parser.add_argument("-p", "--project", help="name of new or path to existing project") + parser.add_argument( + "-d", + "--debug", + type=str2bool, + nargs="?", + const=True, + default=False, + help="enable post-mortem debugging", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + default=23, + help="seed for seed_everything", + ) + parser.add_argument( + "-f", + "--postfix", + type=str, + default="", + help="post-postfix for default name", + ) + parser.add_argument( + "-l", + "--logdir", + type=str, + default="logs", + help="directory for logging dat shit", + ) + parser.add_argument( + "--scale_lr", + type=str2bool, + nargs="?", + const=True, + default=True, + help="scale base-lr by ngpu * batch_size * n_accumulate", + ) + parser.add_argument( + "--use_fp16", + type=str2bool, + nargs="?", + const=True, + default=True, + help="whether to use fp16", + ) + parser.add_argument( + "--flash", + type=str2bool, + const=True, + default=False, + nargs="?", + help="whether to use flash attention", + ) + return parser + + +def nondefault_trainer_args(opt): + parser = argparse.ArgumentParser() + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args([]) + return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) + + +class WrappedDataset(Dataset): + """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" + + def __init__(self, dataset): + self.data = dataset + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +def worker_init_fn(_): + worker_info = torch.utils.data.get_worker_info() + + dataset = worker_info.dataset + worker_id = worker_info.id + + if isinstance(dataset, Txt2ImgIterableBaseDataset): + split_size = dataset.num_records // worker_info.num_workers + # reset num_records to the true number to retain reliable length information + dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] + current_id = np.random.choice(len(np.random.get_state()[1]), 1) + return np.random.seed(np.random.get_state()[1][current_id] + worker_id) + else: + return np.random.seed(np.random.get_state()[1][0] + worker_id) + + +class DataModuleFromConfig(pl.LightningDataModule): + + def __init__(self, + batch_size, + train=None, + validation=None, + test=None, + predict=None, + wrap=False, + num_workers=None, + shuffle_test_loader=False, + use_worker_init_fn=False, + shuffle_val_dataloader=False): + super().__init__() + self.batch_size = batch_size + self.dataset_configs = dict() + self.num_workers = num_workers if num_workers is not None else batch_size * 2 + self.use_worker_init_fn = use_worker_init_fn + if train is not None: + self.dataset_configs["train"] = train + self.train_dataloader = self._train_dataloader + if validation is not None: + self.dataset_configs["validation"] = validation + self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) + if test is not None: + self.dataset_configs["test"] = test + self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) + if predict is not None: + self.dataset_configs["predict"] = predict + self.predict_dataloader = self._predict_dataloader + self.wrap = wrap + + def prepare_data(self): + for data_cfg in self.dataset_configs.values(): + instantiate_from_config(data_cfg) + + def setup(self, stage=None): + self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) + if self.wrap: + for k in self.datasets: + self.datasets[k] = WrappedDataset(self.datasets[k]) + + def _train_dataloader(self): + is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoaderX(self.datasets["train"], + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False if is_iterable_dataset else True, + worker_init_fn=init_fn) + + def _val_dataloader(self, shuffle=False): + if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoaderX(self.datasets["validation"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle) + + def _test_dataloader(self, shuffle=False): + is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + + # do not shuffle dataloader for iterable dataset + shuffle = shuffle and (not is_iterable_dataset) + + return DataLoaderX(self.datasets["test"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle) + + def _predict_dataloader(self, shuffle=False): + if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoaderX(self.datasets["predict"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn) + + +class SetupCallback(Callback): + + def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): + super().__init__() + self.resume = resume + self.now = now + self.logdir = logdir + self.ckptdir = ckptdir + self.cfgdir = cfgdir + self.config = config + self.lightning_config = lightning_config + + def on_keyboard_interrupt(self, trainer, pl_module): + if trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(self.ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + # def on_pretrain_routine_start(self, trainer, pl_module): + def on_fit_start(self, trainer, pl_module): + if trainer.global_rank == 0: + # Create logdirs and save configs + os.makedirs(self.logdir, exist_ok=True) + os.makedirs(self.ckptdir, exist_ok=True) + os.makedirs(self.cfgdir, exist_ok=True) + + if "callbacks" in self.lightning_config: + if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']: + os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) + print("Project config") + print(OmegaConf.to_yaml(self.config)) + OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) + + print("Lightning config") + print(OmegaConf.to_yaml(self.lightning_config)) + OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), + os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) + + else: + # ModelCheckpoint callback created log directory --- remove it + if not self.resume and os.path.exists(self.logdir): + dst, name = os.path.split(self.logdir) + dst = os.path.join(dst, "child_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + try: + os.rename(self.logdir, dst) + except FileNotFoundError: + pass + + +class ImageLogger(Callback): + + def __init__(self, + batch_frequency, + max_images, + clamp=True, + increase_log_steps=True, + rescale=True, + disabled=False, + log_on_batch_idx=False, + log_first_step=False, + log_images_kwargs=None): + super().__init__() + self.rescale = rescale + self.batch_freq = batch_frequency + self.max_images = max_images + self.logger_log_images = { + pl.loggers.CSVLogger: self._testtube, + } + self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)] + if not increase_log_steps: + self.log_steps = [self.batch_freq] + self.clamp = clamp + self.disabled = disabled + self.log_on_batch_idx = log_on_batch_idx + self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} + self.log_first_step = log_first_step + + @rank_zero_only + def _testtube(self, pl_module, images, batch_idx, split): + for k in images: + grid = torchvision.utils.make_grid(images[k]) + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + + tag = f"{split}/{k}" + pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step) + + @rank_zero_only + def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): + root = os.path.join(save_dir, "images", split) + for k in images: + grid = torchvision.utils.make_grid(images[k], nrow=4) + if self.rescale: + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) + grid = grid.numpy() + grid = (grid * 255).astype(np.uint8) + filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) + path = os.path.join(root, filename) + os.makedirs(os.path.split(path)[0], exist_ok=True) + Image.fromarray(grid).save(path) + + def log_img(self, pl_module, batch, batch_idx, split="train"): + check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step + if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 + hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0): + logger = type(pl_module.logger) + + is_train = pl_module.training + if is_train: + pl_module.eval() + + with torch.no_grad(): + images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) + + for k in images: + N = min(images[k].shape[0], self.max_images) + images[k] = images[k][:N] + if isinstance(images[k], torch.Tensor): + images[k] = images[k].detach().cpu() + if self.clamp: + images[k] = torch.clamp(images[k], -1., 1.) + + self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, + batch_idx) + + logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) + logger_log_images(pl_module, images, pl_module.global_step, split) + + if is_train: + pl_module.train() + + def check_frequency(self, check_idx): + if ((check_idx % self.batch_freq) == 0 or + (check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step): + try: + self.log_steps.pop(0) + except IndexError as e: + print(e) + pass + return True + return False + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + # if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): + # self.log_img(pl_module, batch, batch_idx, split="train") + pass + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self.disabled and pl_module.global_step > 0: + self.log_img(pl_module, batch, batch_idx, split="val") + if hasattr(pl_module, 'calibrate_grad_norm'): + if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: + self.log_gradients(trainer, pl_module, batch_idx=batch_idx) + + +class CUDACallback(Callback): + # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py + + def on_train_start(self, trainer, pl_module): + rank_zero_info("Training is starting") + + def on_train_end(self, trainer, pl_module): + rank_zero_info("Training is ending") + + def on_train_epoch_start(self, trainer, pl_module): + # Reset the memory use counter + torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index) + torch.cuda.synchronize(trainer.strategy.root_device.index) + self.start_time = time.time() + + def on_train_epoch_end(self, trainer, pl_module): + torch.cuda.synchronize(trainer.strategy.root_device.index) + max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2**20 + epoch_time = time.time() - self.start_time + + try: + max_memory = trainer.strategy.reduce(max_memory) + epoch_time = trainer.strategy.reduce(epoch_time) + + rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") + rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") + except AttributeError: + pass + + +if __name__ == "__main__": + # custom parser to specify config files, train, test and debug mode, + # postfix, resume. + # `--key value` arguments are interpreted as arguments to the trainer. + # `nested.key=value` arguments are interpreted as config parameters. + # configs are merged from left-to-right followed by command line parameters. + + # model: + # base_learning_rate: float + # target: path to lightning module + # params: + # key: value + # data: + # target: main.DataModuleFromConfig + # params: + # batch_size: int + # wrap: bool + # train: + # target: path to train dataset + # params: + # key: value + # validation: + # target: path to validation dataset + # params: + # key: value + # test: + # target: path to test dataset + # params: + # key: value + # lightning: (optional, has sane defaults and can be specified on cmdline) + # trainer: + # additional arguments to trainer + # logger: + # logger to instantiate + # modelcheckpoint: + # modelcheckpoint to instantiate + # callbacks: + # callback1: + # target: importpath + # params: + # key: value + + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + + # add cwd for convenience and to make classes in this file available when + # running as `python main.py` + # (in particular `main.DataModuleFromConfig`) + sys.path.append(os.getcwd()) + + parser = get_parser() + parser = Trainer.add_argparse_args(parser) + + opt, unknown = parser.parse_known_args() + if opt.name and opt.resume: + raise ValueError("-n/--name and -r/--resume cannot be specified both." + "If you want to resume training in a new log folder, " + "use -n/--name in combination with --resume_from_checkpoint") + if opt.resume: + if not os.path.exists(opt.resume): + raise ValueError("Cannot find {}".format(opt.resume)) + if os.path.isfile(opt.resume): + paths = opt.resume.split("/") + # idx = len(paths)-paths[::-1].index("logs")+1 + # logdir = "/".join(paths[:idx]) + logdir = "/".join(paths[:-2]) + ckpt = opt.resume + else: + assert os.path.isdir(opt.resume), opt.resume + logdir = opt.resume.rstrip("/") + ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") + + opt.resume_from_checkpoint = ckpt + base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) + opt.base = base_configs + opt.base + _tmp = logdir.split("/") + nowname = _tmp[-1] + else: + if opt.name: + name = "_" + opt.name + elif opt.base: + cfg_fname = os.path.split(opt.base[0])[-1] + cfg_name = os.path.splitext(cfg_fname)[0] + name = "_" + cfg_name + else: + name = "" + nowname = now + name + opt.postfix + logdir = os.path.join(opt.logdir, nowname) + + ckptdir = os.path.join(logdir, "checkpoints") + cfgdir = os.path.join(logdir, "configs") + seed_everything(opt.seed) + + try: + # init and save configs + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + lightning_config = config.pop("lightning", OmegaConf.create()) + # merge trainer cli with config + trainer_config = lightning_config.get("trainer", OmegaConf.create()) + + for k in nondefault_trainer_args(opt): + trainer_config[k] = getattr(opt, k) + + print(trainer_config) + if not trainer_config["accelerator"] == "gpu": + del trainer_config["accelerator"] + cpu = True + print("Running on CPU") + else: + cpu = False + print("Running on GPU") + trainer_opt = argparse.Namespace(**trainer_config) + lightning_config.trainer = trainer_config + + # model + use_fp16 = trainer_config.get("precision", 32) == 16 + if use_fp16: + config.model["params"].update({"use_fp16": True}) + print("Using FP16 = {}".format(config.model["params"]["use_fp16"])) + else: + config.model["params"].update({"use_fp16": False}) + print("Using FP16 = {}".format(config.model["params"]["use_fp16"])) + + model = instantiate_from_config(config.model) + # trainer and callbacks + trainer_kwargs = dict() + + # config the logger + # default logger configs + default_logger_cfgs = { + "wandb": { + "target": LIGHTNING_PACK_NAME + "loggers.WandbLogger", + "params": { + "name": nowname, + "save_dir": logdir, + "offline": opt.debug, + "id": nowname, + } + }, + "tensorboard": { + "target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger", + "params": { + "save_dir": logdir, + "name": "diff_tb", + "log_graph": True + } + } + } + + default_logger_cfg = default_logger_cfgs["tensorboard"] + if "logger" in lightning_config: + logger_cfg = lightning_config.logger + else: + logger_cfg = default_logger_cfg + logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) + trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) + + # config the strategy, defualt is ddp + if "strategy" in trainer_config: + strategy_cfg = trainer_config["strategy"] + print("Using strategy: {}".format(strategy_cfg["target"])) + strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"] + else: + strategy_cfg = { + "target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy", + "params": { + "find_unused_parameters": False + } + } + print("Using strategy: DDPStrategy") + + trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg) + + # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to + # specify which metric is used to determine best models + default_modelckpt_cfg = { + "target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint", + "params": { + "dirpath": ckptdir, + "filename": "{epoch:06}", + "verbose": True, + "save_last": True, + } + } + if hasattr(model, "monitor"): + print(f"Monitoring {model.monitor} as checkpoint metric.") + default_modelckpt_cfg["params"]["monitor"] = model.monitor + default_modelckpt_cfg["params"]["save_top_k"] = 3 + + if "modelcheckpoint" in lightning_config: + modelckpt_cfg = lightning_config.modelcheckpoint + else: + modelckpt_cfg = OmegaConf.create() + modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) + print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") + if version.parse(pl.__version__) < version.parse('1.4.0'): + trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) + + # add callback which sets up log directory + default_callbacks_cfg = { + "setup_callback": { + "target": "main.SetupCallback", + "params": { + "resume": opt.resume, + "now": now, + "logdir": logdir, + "ckptdir": ckptdir, + "cfgdir": cfgdir, + "config": config, + "lightning_config": lightning_config, + } + }, + "image_logger": { + "target": "main.ImageLogger", + "params": { + "batch_frequency": 750, + "max_images": 4, + "clamp": True + } + }, + "learning_rate_logger": { + "target": "main.LearningRateMonitor", + "params": { + "logging_interval": "step", + # "log_momentum": True + } + }, + "cuda_callback": { + "target": "main.CUDACallback" + }, + } + if version.parse(pl.__version__) >= version.parse('1.4.0'): + default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg}) + + if "callbacks" in lightning_config: + callbacks_cfg = lightning_config.callbacks + else: + callbacks_cfg = OmegaConf.create() + + if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg: + print( + 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.') + default_metrics_over_trainsteps_ckpt_dict = { + 'metrics_over_trainsteps_checkpoint': { + "target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint', + 'params': { + "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), + "filename": "{epoch:06}-{step:09}", + "verbose": True, + 'save_top_k': -1, + 'every_n_train_steps': 10000, + 'save_weights_only': True + } + } + } + default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) + + callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) + if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'): + callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint + elif 'ignore_keys_callback' in callbacks_cfg: + del callbacks_cfg['ignore_keys_callback'] + + trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] + + trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) + trainer.logdir = logdir ### + + # data + data = instantiate_from_config(config.data) + # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html + # calling these ourselves should not be necessary but it is. + # lightning still takes care of proper multiprocessing though + data.prepare_data() + data.setup() + print("#### Data #####") + for k in data.datasets: + print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") + + # configure learning rate + bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate + if not cpu: + ngpu = trainer_config["devices"] + else: + ngpu = 1 + if 'accumulate_grad_batches' in lightning_config.trainer: + accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches + else: + accumulate_grad_batches = 1 + print(f"accumulate_grad_batches = {accumulate_grad_batches}") + lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches + if opt.scale_lr: + model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr + print( + "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)" + .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) + else: + model.learning_rate = base_lr + print("++++ NOT USING LR SCALING ++++") + print(f"Setting learning rate to {model.learning_rate:.2e}") + + # allow checkpointing via USR1 + def melk(*args, **kwargs): + # run all checkpoint hooks + if trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + def divein(*args, **kwargs): + if trainer.global_rank == 0: + import pudb + pudb.set_trace() + + import signal + + signal.signal(signal.SIGUSR1, melk) + signal.signal(signal.SIGUSR2, divein) + + # run + if opt.train: + try: + trainer.fit(model, data) + except Exception: + melk() + raise + # if not opt.no_test and not trainer.interrupted: + # trainer.test(model, data) + except Exception: + if opt.debug and trainer.global_rank == 0: + try: + import pudb as debugger + except ImportError: + import pdb as debugger + debugger.post_mortem() + raise + finally: + # move newly created debug project to debug_runs + if opt.debug and not opt.resume and trainer.global_rank == 0: + dst, name = os.path.split(logdir) + dst = os.path.join(dst, "debug_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + os.rename(logdir, dst) + if trainer.global_rank == 0: + print(trainer.profiler.summary()) diff --git a/examples/images/diffusion/requirements.txt b/examples/images/diffusion/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..5a83b2aa3e336c04fdd06db3a9c5b4288581b1bc --- /dev/null +++ b/examples/images/diffusion/requirements.txt @@ -0,0 +1,17 @@ +albumentations==1.3.0 +opencv-python +pudb==2019.2 +prefetch_generator +imageio==2.9.0 +imageio-ffmpeg==0.4.2 +torchmetrics==0.6 +omegaconf==2.1.1 +test-tube>=0.7.5 +streamlit>=0.73.1 +einops==0.3.0 +transformers==4.19.2 +webdataset==0.2.5 +open-clip-torch==2.7.0 +gradio==3.11 +datasets +-e . diff --git a/examples/images/diffusion/scripts/download_first_stages.sh b/examples/images/diffusion/scripts/download_first_stages.sh new file mode 100644 index 0000000000000000000000000000000000000000..a8d79e99ccdff0a8d8762f23f3c0642401f32f6c --- /dev/null +++ b/examples/images/diffusion/scripts/download_first_stages.sh @@ -0,0 +1,41 @@ +#!/bin/bash +wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip +wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip +wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip +wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip +wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip +wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip +wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip +wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip +wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip + + + +cd models/first_stage_models/kl-f4 +unzip -o model.zip + +cd ../kl-f8 +unzip -o model.zip + +cd ../kl-f16 +unzip -o model.zip + +cd ../kl-f32 +unzip -o model.zip + +cd ../vq-f4 +unzip -o model.zip + +cd ../vq-f4-noattn +unzip -o model.zip + +cd ../vq-f8 +unzip -o model.zip + +cd ../vq-f8-n256 +unzip -o model.zip + +cd ../vq-f16 +unzip -o model.zip + +cd ../.. \ No newline at end of file diff --git a/examples/images/diffusion/scripts/download_models.sh b/examples/images/diffusion/scripts/download_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..84297d7b8b9a78d241edcd5adaf7d9aa273790de --- /dev/null +++ b/examples/images/diffusion/scripts/download_models.sh @@ -0,0 +1,49 @@ +#!/bin/bash +wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip +wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip +wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip +wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip +wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip +wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip +wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip +wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip +wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip +wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip +wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip + + + +cd models/ldm/celeba256 +unzip -o celeba-256.zip + +cd ../ffhq256 +unzip -o ffhq-256.zip + +cd ../lsun_churches256 +unzip -o lsun_churches-256.zip + +cd ../lsun_beds256 +unzip -o lsun_beds-256.zip + +cd ../text2img256 +unzip -o model.zip + +cd ../cin256 +unzip -o model.zip + +cd ../semantic_synthesis512 +unzip -o model.zip + +cd ../semantic_synthesis256 +unzip -o model.zip + +cd ../bsr_sr +unzip -o model.zip + +cd ../layout2img-openimages256 +unzip -o model.zip + +cd ../inpainting_big +unzip -o model.zip + +cd ../.. diff --git a/examples/images/diffusion/scripts/img2img.py b/examples/images/diffusion/scripts/img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..e8ccfa259c10bdd0a56645d8a1fc5f501f9986cf --- /dev/null +++ b/examples/images/diffusion/scripts/img2img.py @@ -0,0 +1,282 @@ +"""make variations of input image""" + +import argparse, os +import PIL +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange, repeat +from torchvision.utils import make_grid +from torch import autocast +from contextlib import nullcontext +try: + from lightning.pytorch import seed_everything +except: + from pytorch_lightning import seed_everything +from imwatermark import WatermarkEncoder + + +from scripts.txt2img import put_watermark +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +def load_img(path): + image = Image.open(path).convert("RGB") + w, h = image.size + print(f"loaded input image of size ({w}, {h}) from {path}") + w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2. * image - 1. + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + + parser.add_argument( + "--init-img", + type=str, + nargs="?", + help="path to the input image" + ) + + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/img2img-samples" + ) + + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across all samples ", + ) + + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", + ) + + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor, most often 8 or 16", + ) + + parser.add_argument( + "--n_samples", + type=int, + default=2, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + + parser.add_argument( + "--scale", + type=float, + default=9.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + + parser.add_argument( + "--strength", + type=float, + default=0.8, + help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image", + ) + + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v2-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + + opt = parser.parse_args() + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + wm = "SDV2" + wm_encoder = WatermarkEncoder() + wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + assert os.path.isfile(opt.init_img) + init_image = load_img(opt.init_img).to(device) + init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) + init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space + + sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False) + + assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(opt.strength * opt.ddim_steps) + print(f"target t_enc is {t_enc} steps") + + precision_scope = autocast if opt.precision == "autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + + # encode (scaled latent) + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device)) + # decode it + samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, ) + + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + all_samples.append(x_samples) + + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + grid = Image.fromarray(grid.astype(np.uint8)) + grid = put_watermark(grid, wm_encoder) + grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") + + +if __name__ == "__main__": + main() diff --git a/examples/images/diffusion/scripts/inpaint.py b/examples/images/diffusion/scripts/inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..d6e6387a9a3b0afa73fae8af25f43a8ba856240e --- /dev/null +++ b/examples/images/diffusion/scripts/inpaint.py @@ -0,0 +1,98 @@ +import argparse, os, sys, glob +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm +import numpy as np +import torch +from main import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler + + +def make_batch(image, mask, device): + image = np.array(Image.open(image).convert("RGB")) + image = image.astype(np.float32)/255.0 + image = image[None].transpose(0,3,1,2) + image = torch.from_numpy(image) + + mask = np.array(Image.open(mask).convert("L")) + mask = mask.astype(np.float32)/255.0 + mask = mask[None,None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = (1-mask)*image + + batch = {"image": image, "mask": mask, "masked_image": masked_image} + for k in batch: + batch[k] = batch[k].to(device=device) + batch[k] = batch[k]*2.0-1.0 + return batch + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--indir", + type=str, + nargs="?", + help="dir containing image-mask pairs (`example.png` and `example_mask.png`)", + ) + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + opt = parser.parse_args() + + masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png"))) + images = [x.replace("_mask.png", ".png") for x in masks] + print(f"Found {len(masks)} inputs.") + + config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") + model = instantiate_from_config(config.model) + model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], + strict=False) + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + with torch.no_grad(): + with model.ema_scope(): + for image, mask in tqdm(zip(images, masks)): + outpath = os.path.join(opt.outdir, os.path.split(image)[1]) + batch = make_batch(image, mask, device=device) + + # encode masked image and concat downsampled mask + c = model.cond_stage_model.encode(batch["masked_image"]) + cc = torch.nn.functional.interpolate(batch["mask"], + size=c.shape[-2:]) + c = torch.cat((c, cc), dim=1) + + shape = (c.shape[1]-1,)+c.shape[2:] + samples_ddim, _ = sampler.sample(S=opt.steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False) + x_samples_ddim = model.decode_first_stage(samples_ddim) + + image = torch.clamp((batch["image"]+1.0)/2.0, + min=0.0, max=1.0) + mask = torch.clamp((batch["mask"]+1.0)/2.0, + min=0.0, max=1.0) + predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, + min=0.0, max=1.0) + + inpainted = (1-mask)*image+mask*predicted_image + inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 + Image.fromarray(inpainted.astype(np.uint8)).save(outpath) diff --git a/examples/images/diffusion/scripts/knn2img.py b/examples/images/diffusion/scripts/knn2img.py new file mode 100644 index 0000000000000000000000000000000000000000..e6eaaecab53eac9c97051c9a5cb457a240679725 --- /dev/null +++ b/examples/images/diffusion/scripts/knn2img.py @@ -0,0 +1,398 @@ +import argparse, os, sys, glob +import clip +import torch +import torch.nn as nn +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange, repeat +from torchvision.utils import make_grid +import scann +import time +from multiprocessing import cpu_count + +from ldm.util import instantiate_from_config, parallel_data_prefetch +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder + +DATABASES = [ + "openimages", + "artbench-art_nouveau", + "artbench-baroque", + "artbench-expressionism", + "artbench-impressionism", + "artbench-post_impressionism", + "artbench-realism", + "artbench-romanticism", + "artbench-renaissance", + "artbench-surrealism", + "artbench-ukiyo_e", +] + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +class Searcher(object): + def __init__(self, database, retriever_version='ViT-L/14'): + assert database in DATABASES + # self.database = self.load_database(database) + self.database_name = database + self.searcher_savedir = f'data/rdm/searchers/{self.database_name}' + self.database_path = f'data/rdm/retrieval_databases/{self.database_name}' + self.retriever = self.load_retriever(version=retriever_version) + self.database = {'embedding': [], + 'img_id': [], + 'patch_coords': []} + self.load_database() + self.load_searcher() + + def train_searcher(self, k, + metric='dot_product', + searcher_savedir=None): + + print('Start training searcher') + searcher = scann.scann_ops_pybind.builder(self.database['embedding'] / + np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis], + k, metric) + self.searcher = searcher.score_brute_force().build() + print('Finish training searcher') + + if searcher_savedir is not None: + print(f'Save trained searcher under "{searcher_savedir}"') + os.makedirs(searcher_savedir, exist_ok=True) + self.searcher.serialize(searcher_savedir) + + def load_single_file(self, saved_embeddings): + compressed = np.load(saved_embeddings) + self.database = {key: compressed[key] for key in compressed.files} + print('Finished loading of clip embeddings.') + + def load_multi_files(self, data_archive): + out_data = {key: [] for key in self.database} + for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for key in d.files: + out_data[key].append(d[key]) + + return out_data + + def load_database(self): + + print(f'Load saved patch embedding from "{self.database_path}"') + file_content = glob.glob(os.path.join(self.database_path, '*.npz')) + + if len(file_content) == 1: + self.load_single_file(file_content[0]) + elif len(file_content) > 1: + data = [np.load(f) for f in file_content] + prefetched_data = parallel_data_prefetch(self.load_multi_files, data, + n_proc=min(len(data), cpu_count()), target_data_type='dict') + + self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in + self.database} + else: + raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?') + + print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.') + + def load_retriever(self, version='ViT-L/14', ): + model = FrozenClipImageEmbedder(model=version) + if torch.cuda.is_available(): + model.cuda() + model.eval() + return model + + def load_searcher(self): + print(f'load searcher for database {self.database_name} from {self.searcher_savedir}') + self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir) + print('Finished loading searcher.') + + def search(self, x, k): + if self.searcher is None and self.database['embedding'].shape[0] < 2e4: + self.train_searcher(k) # quickly fit searcher on the fly for small databases + assert self.searcher is not None, 'Cannot search with uninitialized searcher' + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + if len(x.shape) == 3: + x = x[:, 0] + query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis] + + start = time.time() + nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k) + end = time.time() + + out_embeddings = self.database['embedding'][nns] + out_img_ids = self.database['img_id'][nns] + out_pc = self.database['patch_coords'][nns] + + out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], + 'img_ids': out_img_ids, + 'patch_coords': out_pc, + 'queries': x, + 'exec_time': end - start, + 'nns': nns, + 'q_embeddings': query_embeddings} + + return out + + def __call__(self, x, n): + return self.search(x, n) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc) + # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt? + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + + parser.add_argument( + "--n_repeat", + type=int, + default=1, + help="number of repeats in CLIP latent space", + ) + + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", + ) + + parser.add_argument( + "--H", + type=int, + default=768, + help="image height, in pixel space", + ) + + parser.add_argument( + "--W", + type=int, + default=768, + help="image width, in pixel space", + ) + + parser.add_argument( + "--n_samples", + type=int, + default=3, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + + parser.add_argument( + "--scale", + type=float, + default=5.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + + parser.add_argument( + "--config", + type=str, + default="configs/retrieval-augmented-diffusion/768x768.yaml", + help="path to config which constructs model", + ) + + parser.add_argument( + "--ckpt", + type=str, + default="models/rdm/rdm768x768/model.ckpt", + help="path to checkpoint of model", + ) + + parser.add_argument( + "--clip_type", + type=str, + default="ViT-L/14", + help="which CLIP model to use for retrieval and NN encoding", + ) + parser.add_argument( + "--database", + type=str, + default='artbench-surrealism', + choices=DATABASES, + help="The database used for the search, only applied when --use_neighbors=True", + ) + parser.add_argument( + "--use_neighbors", + default=False, + action='store_true', + help="Include neighbors in addition to text prompt for conditioning", + ) + parser.add_argument( + "--knn", + default=10, + type=int, + help="The number of included neighbors, only applied when --use_neighbors=True", + ) + + opt = parser.parse_args() + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device) + + if opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + print(f"sampling scale for cfg is {opt.scale:.2f}") + + searcher = None + if opt.use_neighbors: + searcher = Searcher(opt.database) + + with torch.no_grad(): + with model.ema_scope(): + for n in trange(opt.n_iter, desc="Sampling"): + all_samples = list() + for prompts in tqdm(data, desc="data"): + print("sampling prompts:", prompts) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = clip_text_encoder.encode(prompts) + uc = None + if searcher is not None: + nn_dict = searcher(c, opt.knn) + c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1) + if opt.scale != 1.0: + uc = torch.zeros_like(c) + if isinstance(prompts, tuple): + prompts = list(prompts) + shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + ) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples_ddim: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + Image.fromarray(x_sample.astype(np.uint8)).save( + os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + all_samples.append(x_samples_ddim) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") diff --git a/examples/images/diffusion/scripts/sample_diffusion.py b/examples/images/diffusion/scripts/sample_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..876fe3c3642fcc8c7209e4f763c0134166615f78 --- /dev/null +++ b/examples/images/diffusion/scripts/sample_diffusion.py @@ -0,0 +1,313 @@ +import argparse, os, sys, glob, datetime, yaml +import torch +import time +import numpy as np +from tqdm import trange + +from omegaconf import OmegaConf +from PIL import Image + +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config + +rescale = lambda x: (x + 1.) / 2. + +def custom_to_pil(x): + x = x.detach().cpu() + x = torch.clamp(x, -1., 1.) + x = (x + 1.) / 2. + x = x.permute(1, 2, 0).numpy() + x = (255 * x).astype(np.uint8) + x = Image.fromarray(x) + if not x.mode == "RGB": + x = x.convert("RGB") + return x + + +def custom_to_np(x): + # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py + sample = x.detach().cpu() + sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8) + sample = sample.permute(0, 2, 3, 1) + sample = sample.contiguous() + return sample + + +def logs2pil(logs, keys=["sample"]): + imgs = dict() + for k in logs: + try: + if len(logs[k].shape) == 4: + img = custom_to_pil(logs[k][0, ...]) + elif len(logs[k].shape) == 3: + img = custom_to_pil(logs[k]) + else: + print(f"Unknown format for key {k}. ") + img = None + except: + img = None + imgs[k] = img + return imgs + + +@torch.no_grad() +def convsample(model, shape, return_intermediates=True, + verbose=True, + make_prog_row=False): + + + if not make_prog_row: + return model.p_sample_loop(None, shape, + return_intermediates=return_intermediates, verbose=verbose) + else: + return model.progressive_denoising( + None, shape, verbose=True + ) + + +@torch.no_grad() +def convsample_ddim(model, steps, shape, eta=1.0 + ): + ddim = DDIMSampler(model) + bs = shape[0] + shape = shape[1:] + samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,) + return samples, intermediates + + +@torch.no_grad() +def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,): + + + log = dict() + + shape = [batch_size, + model.model.diffusion_model.in_channels, + model.model.diffusion_model.image_size, + model.model.diffusion_model.image_size] + + with model.ema_scope("Plotting"): + t0 = time.time() + if vanilla: + sample, progrow = convsample(model, shape, + make_prog_row=True) + else: + sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, + eta=eta) + + t1 = time.time() + + x_sample = model.decode_first_stage(sample) + + log["sample"] = x_sample + log["time"] = t1 - t0 + log['throughput'] = sample.shape[0] / (t1 - t0) + print(f'Throughput for this batch: {log["throughput"]}') + return log + +def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None): + if vanilla: + print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.') + else: + print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}') + + + tstart = time.time() + n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1 + # path = logdir + if model.cond_stage_model is None: + all_images = [] + + print(f"Running unconditional sampling for {n_samples} samples") + for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"): + logs = make_convolutional_sample(model, batch_size=batch_size, + vanilla=vanilla, custom_steps=custom_steps, + eta=eta) + n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample") + all_images.extend([custom_to_np(logs["sample"])]) + if n_saved >= n_samples: + print(f'Finish after generating {n_saved} samples') + break + all_img = np.concatenate(all_images, axis=0) + all_img = all_img[:n_samples] + shape_str = "x".join([str(x) for x in all_img.shape]) + nppath = os.path.join(nplog, f"{shape_str}-samples.npz") + np.savez(nppath, all_img) + + else: + raise NotImplementedError('Currently only sampling for unconditional models supported.') + + print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.") + + +def save_logs(logs, path, n_saved=0, key="sample", np_path=None): + for k in logs: + if k == key: + batch = logs[key] + if np_path is None: + for x in batch: + img = custom_to_pil(x) + imgpath = os.path.join(path, f"{key}_{n_saved:06}.png") + img.save(imgpath) + n_saved += 1 + else: + npbatch = custom_to_np(batch) + shape_str = "x".join([str(x) for x in npbatch.shape]) + nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz") + np.savez(nppath, npbatch) + n_saved += npbatch.shape[0] + return n_saved + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-r", + "--resume", + type=str, + nargs="?", + help="load from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-n", + "--n_samples", + type=int, + nargs="?", + help="number of samples to draw", + default=50000 + ) + parser.add_argument( + "-e", + "--eta", + type=float, + nargs="?", + help="eta for ddim sampling (0.0 yields deterministic sampling)", + default=1.0 + ) + parser.add_argument( + "-v", + "--vanilla_sample", + default=False, + action='store_true', + help="vanilla sampling (default option is DDIM sampling)?", + ) + parser.add_argument( + "-l", + "--logdir", + type=str, + nargs="?", + help="extra logdir", + default="none" + ) + parser.add_argument( + "-c", + "--custom_steps", + type=int, + nargs="?", + help="number of steps for ddim and fastdpm sampling", + default=50 + ) + parser.add_argument( + "--batch_size", + type=int, + nargs="?", + help="the bs", + default=10 + ) + return parser + + +def load_model_from_config(config, sd): + model = instantiate_from_config(config) + model.load_state_dict(sd,strict=False) + model.cuda() + model.eval() + return model + + +def load_model(config, ckpt, gpu, eval_mode): + if ckpt: + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + global_step = pl_sd["global_step"] + else: + pl_sd = {"state_dict": None} + global_step = None + model = load_model_from_config(config.model, + pl_sd["state_dict"]) + + return model, global_step + + +if __name__ == "__main__": + now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + sys.path.append(os.getcwd()) + command = " ".join(sys.argv) + + parser = get_parser() + opt, unknown = parser.parse_known_args() + ckpt = None + + if not os.path.exists(opt.resume): + raise ValueError("Cannot find {}".format(opt.resume)) + if os.path.isfile(opt.resume): + # paths = opt.resume.split("/") + try: + logdir = '/'.join(opt.resume.split('/')[:-1]) + # idx = len(paths)-paths[::-1].index("logs")+1 + print(f'Logdir is {logdir}') + except ValueError: + paths = opt.resume.split("/") + idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt + logdir = "/".join(paths[:idx]) + ckpt = opt.resume + else: + assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory" + logdir = opt.resume.rstrip("/") + ckpt = os.path.join(logdir, "model.ckpt") + + base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml"))) + opt.base = base_configs + + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + + gpu = True + eval_mode = True + + if opt.logdir != "none": + locallog = logdir.split(os.sep)[-1] + if locallog == "": locallog = logdir.split(os.sep)[-2] + print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'") + logdir = os.path.join(opt.logdir, locallog) + + print(config) + + model, global_step = load_model(config, ckpt, gpu, eval_mode) + print(f"global step: {global_step}") + print(75 * "=") + print("logging to:") + logdir = os.path.join(logdir, "samples", f"{global_step:08}", now) + imglogdir = os.path.join(logdir, "img") + numpylogdir = os.path.join(logdir, "numpy") + + os.makedirs(imglogdir) + os.makedirs(numpylogdir) + print(logdir) + print(75 * "=") + + # write config out + sampling_file = os.path.join(logdir, "sampling_config.yaml") + sampling_conf = vars(opt) + + with open(sampling_file, 'w') as f: + yaml.dump(sampling_conf, f, default_flow_style=False) + print(sampling_conf) + + + run(model, imglogdir, eta=opt.eta, + vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps, + batch_size=opt.batch_size, nplog=numpylogdir) + + print("done.") diff --git a/examples/images/diffusion/scripts/tests/test_checkpoint.py b/examples/images/diffusion/scripts/tests/test_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..a32e66d44cf2479d4dcc05d469cf7b4210d2c67d --- /dev/null +++ b/examples/images/diffusion/scripts/tests/test_checkpoint.py @@ -0,0 +1,37 @@ +import os +import sys +from copy import deepcopy + +import yaml +from datetime import datetime + +from diffusers import StableDiffusionPipeline +import torch +from ldm.util import instantiate_from_config +from main import get_parser + +if __name__ == "__main__": + with torch.no_grad(): + yaml_path = "../../train_colossalai.yaml" + with open(yaml_path, 'r', encoding='utf-8') as f: + config = f.read() + base_config = yaml.load(config, Loader=yaml.FullLoader) + unet_config = base_config['model']['params']['unet_config'] + diffusion_model = instantiate_from_config(unet_config).to("cuda:0") + + pipe = StableDiffusionPipeline.from_pretrained( + "/data/scratch/diffuser/stable-diffusion-v1-4" + ).to("cuda:0") + dif_model_2 = pipe.unet + + random_input_ = torch.rand((4, 4, 32, 32)).to("cuda:0") + random_input_2 = torch.clone(random_input_).to("cuda:0") + time_stamp = torch.randint(20, (4,)).to("cuda:0") + time_stamp2 = torch.clone(time_stamp).to("cuda:0") + context_ = torch.rand((4, 77, 768)).to("cuda:0") + context_2 = torch.clone(context_).to("cuda:0") + + out_1 = diffusion_model(random_input_, time_stamp, context_) + out_2 = dif_model_2(random_input_2, time_stamp2, context_2) + print(out_1.shape) + print(out_2['sample'].shape) \ No newline at end of file diff --git a/examples/images/diffusion/scripts/tests/test_watermark.py b/examples/images/diffusion/scripts/tests/test_watermark.py new file mode 100644 index 0000000000000000000000000000000000000000..f93f8a6e70763c0e284157bc8225827520b2f5ef --- /dev/null +++ b/examples/images/diffusion/scripts/tests/test_watermark.py @@ -0,0 +1,18 @@ +import cv2 +import fire +from imwatermark import WatermarkDecoder + + +def testit(img_path): + bgr = cv2.imread(img_path) + decoder = WatermarkDecoder('bytes', 136) + watermark = decoder.decode(bgr, 'dwtDct') + try: + dec = watermark.decode('utf-8') + except: + dec = "null" + print(dec) + + +if __name__ == "__main__": + fire.Fire(testit) \ No newline at end of file diff --git a/examples/images/diffusion/scripts/train_searcher.py b/examples/images/diffusion/scripts/train_searcher.py new file mode 100644 index 0000000000000000000000000000000000000000..1e7904889c0145f9fb740fd4ae8e45c08728b255 --- /dev/null +++ b/examples/images/diffusion/scripts/train_searcher.py @@ -0,0 +1,147 @@ +import os, sys +import numpy as np +import scann +import argparse +import glob +from multiprocessing import cpu_count +from tqdm import tqdm + +from ldm.util import parallel_data_prefetch + + +def search_bruteforce(searcher): + return searcher.score_brute_force().build() + + +def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, + partioning_trainsize, num_leaves, num_leaves_to_search): + return searcher.tree(num_leaves=num_leaves, + num_leaves_to_search=num_leaves_to_search, + training_sample_size=partioning_trainsize). \ + score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() + + +def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): + return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( + reorder_k).build() + +def load_datapool(dpath): + + + def load_single_file(saved_embeddings): + compressed = np.load(saved_embeddings) + database = {key: compressed[key] for key in compressed.files} + return database + + def load_multi_files(data_archive): + database = {key: [] for key in data_archive[0].files} + for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for key in d.files: + database[key].append(d[key]) + + return database + + print(f'Load saved patch embedding from "{dpath}"') + file_content = glob.glob(os.path.join(dpath, '*.npz')) + + if len(file_content) == 1: + data_pool = load_single_file(file_content[0]) + elif len(file_content) > 1: + data = [np.load(f) for f in file_content] + prefetched_data = parallel_data_prefetch(load_multi_files, data, + n_proc=min(len(data), cpu_count()), target_data_type='dict') + + data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} + else: + raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') + + print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.') + return data_pool + + +def train_searcher(opt, + metric='dot_product', + partioning_trainsize=None, + reorder_k=None, + # todo tune + aiq_thld=0.2, + dims_per_block=2, + num_leaves=None, + num_leaves_to_search=None,): + + data_pool = load_datapool(opt.database) + k = opt.knn + + if not reorder_k: + reorder_k = 2 * k + + # normalize + # embeddings = + searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) + pool_size = data_pool['embedding'].shape[0] + + print(*(['#'] * 100)) + print('Initializing scaNN searcher with the following values:') + print(f'k: {k}') + print(f'metric: {metric}') + print(f'reorder_k: {reorder_k}') + print(f'anisotropic_quantization_threshold: {aiq_thld}') + print(f'dims_per_block: {dims_per_block}') + print(*(['#'] * 100)) + print('Start training searcher....') + print(f'N samples in pool is {pool_size}') + + # this reflects the recommended design choices proposed at + # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md + if pool_size < 2e4: + print('Using brute force search.') + searcher = search_bruteforce(searcher) + elif 2e4 <= pool_size and pool_size < 1e5: + print('Using asymmetric hashing search and reordering.') + searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) + else: + print('Using using partioning, asymmetric hashing search and reordering.') + + if not partioning_trainsize: + partioning_trainsize = data_pool['embedding'].shape[0] // 10 + if not num_leaves: + num_leaves = int(np.sqrt(pool_size)) + + if not num_leaves_to_search: + num_leaves_to_search = max(num_leaves // 20, 1) + + print('Partitioning params:') + print(f'num_leaves: {num_leaves}') + print(f'num_leaves_to_search: {num_leaves_to_search}') + # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) + searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, + partioning_trainsize, num_leaves, num_leaves_to_search) + + print('Finish training searcher') + searcher_savedir = opt.target_path + os.makedirs(searcher_savedir, exist_ok=True) + searcher.serialize(searcher_savedir) + print(f'Saved trained searcher under "{searcher_savedir}"') + +if __name__ == '__main__': + sys.path.append(os.getcwd()) + parser = argparse.ArgumentParser() + parser.add_argument('--database', + '-d', + default='data/rdm/retrieval_databases/openimages', + type=str, + help='path to folder containing the clip feature of the database') + parser.add_argument('--target_path', + '-t', + default='data/rdm/searchers/openimages', + type=str, + help='path to the target folder where the searcher shall be stored.') + parser.add_argument('--knn', + '-k', + default=20, + type=int, + help='number of nearest neighbors, for which the searcher shall be optimized') + + opt, _ = parser.parse_known_args() + + train_searcher(opt,) \ No newline at end of file diff --git a/examples/images/diffusion/scripts/txt2img.py b/examples/images/diffusion/scripts/txt2img.py new file mode 100644 index 0000000000000000000000000000000000000000..15993008f179a63bfaba998b5516c452b5d4a7dc --- /dev/null +++ b/examples/images/diffusion/scripts/txt2img.py @@ -0,0 +1,292 @@ +import argparse, os +import cv2 +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange +from torchvision.utils import make_grid +try: + from lightning.pytorch import seed_everything +except: + from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import nullcontext +from imwatermark import WatermarkEncoder + +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.dpm_solver import DPMSolverSampler + +torch.set_grad_enabled(False) + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a professional photograph of an astronaut riding a triceratops", + help="the prompt to render" + ) + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--dpm", + action='store_true', + help="use DPM (2) sampler", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across all samples ", + ) + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=3, + help="sample this often", + ) + parser.add_argument( + "--H", + type=int, + default=512, + help="image height, in pixel space", + ) + parser.add_argument( + "--W", + type=int, + default=512, + help="image width, in pixel space", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor, most often 8 or 16", + ) + parser.add_argument( + "--n_samples", + type=int, + default=3, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=9.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file, separated by newlines", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v2-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + parser.add_argument( + "--repeat", + type=int, + default=1, + help="repeat each prompt in file this often", + ) + opt = parser.parse_args() + return opt + + +def put_watermark(img, wm_encoder=None): + if wm_encoder is not None: + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + img = wm_encoder.encode(img, 'dwtDct') + img = Image.fromarray(img[:, :, ::-1]) + return img + + +def main(opt): + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + if opt.plms: + sampler = PLMSSampler(model) + elif opt.dpm: + sampler = DPMSolverSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + wm = "SDV2" + wm_encoder = WatermarkEncoder() + wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = [p for p in data for i in range(opt.repeat)] + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + sample_count = 0 + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + start_code = None + if opt.fixed_code: + start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) + + precision_scope = autocast if opt.precision == "autocast" else nullcontext + with torch.no_grad(), \ + precision_scope("cuda"), \ + model.ema_scope(): + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples, _ = sampler.sample(S=opt.steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code) + + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + sample_count += 1 + + all_samples.append(x_samples) + + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + grid = Image.fromarray(grid.astype(np.uint8)) + grid = put_watermark(grid, wm_encoder) + grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" + f" \nEnjoy.") + + +if __name__ == "__main__": + opt = parse_args() + main(opt) diff --git a/examples/images/diffusion/scripts/txt2img.sh b/examples/images/diffusion/scripts/txt2img.sh new file mode 100644 index 0000000000000000000000000000000000000000..549bb03a6885a4d2662dd53c09b534854f3e228d --- /dev/null +++ b/examples/images/diffusion/scripts/txt2img.sh @@ -0,0 +1,6 @@ +python scripts/txt2img.py --prompt "Teyvat, Name:Layla, Element: Cryo, Weapon:Sword, Region:Sumeru, Model type:Medium Female, Description:a woman in a blue outfit holding a sword" --plms \ + --outdir ./output \ + --config /home/lcmql/data2/Genshin/2022-11-18T16-38-46_train_colossalai_teyvattest/checkpoints/last.ckpt \ + --ckpt /home/lcmql/data2/Genshin/2022-11-18T16-38-46_train_colossalai_teyvattest/configs/2022-11-18T16-38-46-project.yaml \ + --n_samples 4 + diff --git a/examples/images/diffusion/setup.py b/examples/images/diffusion/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..a24d541676407eee1bea271179ffd1d80c6a8e79 --- /dev/null +++ b/examples/images/diffusion/setup.py @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages + +setup( + name='latent-diffusion', + version='0.0.1', + description='', + packages=find_packages(), + install_requires=[ + 'torch', + 'numpy', + 'tqdm', + ], +) \ No newline at end of file diff --git a/examples/images/diffusion/train.sh b/examples/images/diffusion/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..ed9ae4b75f6c20417e3d85ad9837640faf209a78 --- /dev/null +++ b/examples/images/diffusion/train.sh @@ -0,0 +1,5 @@ +# HF_DATASETS_OFFLINE=1 +# TRANSFORMERS_OFFLINE=1 +# DIFFUSERS_OFFLINE=1 + +python main.py --logdir /tmp/ -t -b configs/Teyvat/train_colossalai_teyvat.yaml diff --git a/examples/images/vit/README.md b/examples/images/vit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f78c037ef990d04505e0f0f6fdf2696b693215ab --- /dev/null +++ b/examples/images/vit/README.md @@ -0,0 +1,61 @@ +# Vision Transformer with ColoTensor + +# Overview + +In this example, we will run Vision Transformer with ColoTensor. + +We use model **ViTForImageClassification** from Hugging Face [Link](https://huggingface.co/docs/transformers/model_doc/vit) for unit test. +You can change world size or decide whether use DDP in our code. + +We use model **vision_transformer** from timm [Link](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) for training example. + +(2022/6/28) The default configuration now supports 2DP+2TP with gradient accumulation and checkpoint support. Zero is not supported at present. + +# Requirement + +You should install colossalai from main branch with commit 561e904. + +## Unit test +To run unit test, you should install pytest, transformers with: +```shell +pip install pytest transformers +``` + +## Training example +To run training example with ViT-S, you should install **NVIDIA DALI** from [Link](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html) for dataloader support. +You also need to install timm and titans for model/dataloader support with: +```shell +pip install timm titans +``` + +### Data preparation +You can download the ImageNet dataset from the [ImageNet official website](https://www.image-net.org/download.php). You should get the raw images after downloading the dataset. As we use **NVIDIA DALI** to read data, we use the TFRecords dataset instead of raw Imagenet dataset. This offers better speedup to IO. If you don't have TFRecords dataset, follow [imagenet-tools](https://github.com/ver217/imagenet-tools) to build one. + +Before you start training, you need to set the environment variable `DATA` so that the script knows where to fetch the data for DALI dataloader. +```shell +export DATA=/path/to/ILSVRC2012 +``` + + +# How to run + +## Unit test +In your terminal +```shell +pytest test_vit.py +``` + +This will evaluate models with different **world_size** and **use_ddp**. + +## Training example +Modify the settings in run.sh according to your environment. +For example, if you set `--nproc_per_node=8` in `run.sh` and `TP_WORLD_SIZE=2` in your config file, +data parallel size will be automatically calculated as 4. +Thus, the parallel strategy is set to 4DP+2TP. + +Then in your terminal +```shell +sh run.sh +``` + +This will start ViT-S training with ImageNet. diff --git a/examples/images/vit/configs/vit_1d_tp2.py b/examples/images/vit/configs/vit_1d_tp2.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf399f2e50daaa70f52726ccfa6f4e035ce7380 --- /dev/null +++ b/examples/images/vit/configs/vit_1d_tp2.py @@ -0,0 +1,32 @@ +from colossalai.amp import AMP_TYPE + +# hyperparameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 256 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 300 +WARMUP_EPOCHS = 32 + +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 384 +DEPTH = 12 +NUM_HEADS = 6 +MLP_RATIO = 4 +NUM_CLASSES = 1000 +CHECKPOINT = False +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token + +USE_DDP = True +TP_WORLD_SIZE = 2 +TP_TYPE = 'row' +parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),) + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 +gradient_accumulation = 8 + +LOG_PATH = "./log" diff --git a/examples/images/vit/run.sh b/examples/images/vit/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..84fe58f11a6a7c25242eab1714fd94698509be9a --- /dev/null +++ b/examples/images/vit/run.sh @@ -0,0 +1,15 @@ +export DATA=/data/scratch/imagenet/tf_records +export OMP_NUM_THREADS=4 + +# resume +# CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \ +# --nproc_per_node 4 train.py \ +# --config configs/vit_1d_tp2.py \ +# --resume_from checkpoint/epoch_10 \ +# --master_port 29598 | tee ./out 2>&1 + +# train +CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \ +--nproc_per_node 4 train.py \ +--config configs/vit_1d_tp2.py \ +--master_port 29598 | tee ./out 2>&1 diff --git a/examples/images/vit/test_vit.py b/examples/images/vit/test_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..7dbbe607ef5a7f48a6fd2329b9f9faade10f4109 --- /dev/null +++ b/examples/images/vit/test_vit.py @@ -0,0 +1,132 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP +from utils.util import set_seed, tensor_equal, tensor_shard_equal +from vit import get_training_components + +import colossalai +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.parallel.data_parallel import ColoDDP +from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext + + +# Only for all Linear, it's 1d_row split because Linear will be transposed when calculating. +# But for other layers, it's 1d_col split. +# Layernorm is not supported for now. +# patch_embeddings.projection has nn.Conv2d +# https://github.com/huggingface/transformers/blob/dcb08b99f44919425f8ba9be9ddcc041af8ec25e/src/transformers/models/vit/modeling_vit.py#L182 +def init_1d_row_for_linear_weight_spec(model, world_size: int): + pg = ProcessGroup(tp_degree=world_size) + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + for n, p in model.named_parameters(): + if 'weight' in n and 'layernorm' not in n and 'embeddings.patch_embeddings.projection.weight' not in n: + p.set_process_group(pg) + p.set_tensor_spec(*spec) + + +# Similarly, it's col split for Linear but row split for others. +def init_1d_col_for_linear_weight_bias_spec(model, world_size: int): + pg = ProcessGroup(tp_degree=world_size) + spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + for n, p in model.named_parameters(): + if ('weight' in n + or 'bias' in n) and 'layernorm' not in n and 'embeddings.patch_embeddings.projection' not in n: + p.set_process_group(pg) + p.set_tensor_spec(*spec) + + +def check_param_equal(model, torch_model): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + assert tensor_shard_equal(torch_p, p) + + +def check_grad_equal(model, torch_model): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + if (torch_p.grad.shape == p.grad.shape): + assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0) == True + else: + dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape)) + dim = dims_not_eq.item() + world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + assert torch.allclose(torch_p.grad.chunk(world_size, dim)[rank], p.grad, rtol=1e-3, atol=2.0) == True + + +def run_vit(init_spec_func, use_ddp): + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_training_components() + with ColoInitContext(device=get_current_device()): + model = model_builder() + model = model.cuda() + torch_model = model_builder().cuda() + if use_ddp: + model = ColoDDP(model) + torch_model = DDP(torch_model, + device_ids=[gpc.get_global_rank()], + process_group=gpc.get_group(ParallelMode.DATA)) + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p) + + world_size = torch.distributed.get_world_size() + init_spec_func(model, world_size) + + check_param_equal(model, torch_model) + model.train() + torch_model.train() + set_seed(gpc.get_local_rank(ParallelMode.DATA)) + + optimizer = optimizer_class(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + torch_optimizer = optimizer_class(torch_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + + for i, image_dict in enumerate(train_dataloader): + if use_ddp: + model.zero_grad() + else: + optimizer.zero_grad() + logits = model(image_dict['pixel_values']) + torch_logits = torch_model(image_dict['pixel_values']) + assert tensor_equal(torch_logits.logits, logits.logits) + loss = criterion(logits.logits, image_dict['label']) + torch_loss = criterion(torch_logits.logits, image_dict['label']) + if use_ddp: + model.backward(loss) + else: + loss.backward() + torch_loss.backward() + check_grad_equal(model, torch_model) + optimizer.step() + torch_optimizer.step() + check_param_equal(model, torch_model) + break + + +def run_dist(rank, world_size, port, use_ddp): + if use_ddp and world_size == 1: + return + tp_world_size = world_size // 2 if use_ddp else world_size + config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_vit(init_1d_row_for_linear_weight_spec, use_ddp) + run_vit(init_1d_col_for_linear_weight_bias_spec, use_ddp) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize('use_ddp', [False, True]) +@rerun_if_address_is_in_use() +def test_vit(world_size, use_ddp): + run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_vit(1, False) diff --git a/examples/images/vit/train.py b/examples/images/vit/train.py new file mode 100644 index 0000000000000000000000000000000000000000..de39801c79728e507ca452120376ca1af0183a5e --- /dev/null +++ b/examples/images/vit/train.py @@ -0,0 +1,161 @@ +import os + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from timm.models.vision_transformer import _create_vision_transformer +from titans.dataloader.imagenet import build_dali_imagenet +from tqdm import tqdm + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn._ops import * +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.parallel.data_parallel import ColoDDP +from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec +from colossalai.utils import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext + + +def init_1d_row_for_linear_weight_spec(model, world_size: int): + pg = ProcessGroup(tp_degree=world_size) + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + for n, p in model.named_parameters(): + if 'weight' in n and 'norm' not in n and 'patch_embed.proj.weight' not in n: + p.set_process_group(pg) + p.set_tensor_spec(*spec) + + +# Similarly, it's col split for Linear but row split for others. +def init_1d_col_for_linear_weight_bias_spec(model, world_size: int): + pg = ProcessGroup(tp_degree=world_size) + spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + for n, p in model.named_parameters(): + if ('weight' in n or 'bias' in n) and 'norm' not in n and ('patch_embed.proj.weight' not in n + and 'patch_embed.proj.bias' not in n): + p.set_process_group(pg) + p.set_tensor_spec(*spec) + + +def init_spec_func(model, tp_type): + world_size = torch.distributed.get_world_size() + if tp_type == 'row': + init_1d_row_for_linear_weight_spec(model, world_size) + elif tp_type == 'col': + init_1d_col_for_linear_weight_bias_spec(model, world_size) + else: + raise NotImplemented + + +def train_imagenet(): + + parser = colossalai.get_default_parser() + parser.add_argument('--from_torch', default=True, action='store_true') + parser.add_argument('--resume_from', default=False) + + args = parser.parse_args() + colossalai.launch_from_torch(config=args.config) + use_ddp = gpc.config.USE_DDP + + disable_existing_loggers() + + logger = get_dist_logger() + if hasattr(gpc.config, 'LOG_PATH'): + if gpc.get_global_rank() == 0: + log_path = gpc.config.LOG_PATH + if not os.path.exists(log_path): + os.mkdir(log_path) + logger.log_to_file(log_path) + + logger.info('Build data loader', ranks=[0]) + root = os.environ['DATA'] + train_dataloader, test_dataloader = build_dali_imagenet(root, + train_batch_size=gpc.config.BATCH_SIZE, + test_batch_size=gpc.config.BATCH_SIZE) + + logger.info('Build model', ranks=[0]) + + model_kwargs = dict(img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + embed_dim=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=gpc.config.NUM_CLASSES, + drop_rate=0.1, + attn_drop_rate=0.1, + weight_init='jax') + + with ColoInitContext(device=get_current_device()): + model = _create_vision_transformer('vit_small_patch16_224', pretrained=False, **model_kwargs) + init_spec_func(model, gpc.config.TP_TYPE) + + world_size = torch.distributed.get_world_size() + model = ColoDDP(module=model, process_group=ProcessGroup(tp_degree=world_size)) + logger.info('Build criterion, optimizer, lr_scheduler', ranks=[0]) + optimizer = HybridAdam(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) + + criterion = CrossEntropyLoss() + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, + total_steps=gpc.config.NUM_EPOCHS, + warmup_steps=gpc.config.WARMUP_EPOCHS) + + start_epoch = 0 + if args.resume_from: + load_model = torch.load(args.resume_from + '_model.pth') + start_epoch = load_model['epoch'] + model.load_state_dict(load_model['model']) + load_optim = torch.load(args.resume_from + '_optim_rank_{}.pth'.format(dist.get_rank())) + optimizer.load_state_dict(load_optim['optim']) + + for epoch in range(start_epoch, gpc.config.NUM_EPOCHS): + model.train() + for index, (x, y) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False): + x, y = x.cuda(), y.cuda() + output = model(x) + loss = criterion(output, y) + loss = loss / gpc.config.gradient_accumulation + if use_ddp: + model.backward(loss) + else: + loss.backward() + if (index + 1) % gpc.config.gradient_accumulation == 0: + optimizer.step() + if use_ddp: + model.zero_grad() + else: + optimizer.zero_grad() + + logger.info( + f"Finish Train Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {loss.item():.3f} lr: {optimizer.state_dict()['param_groups'][0]['lr']}", + ranks=[0]) + + model.eval() + test_loss = 0 + correct = 0 + test_sum = 0 + with torch.no_grad(): + for index, (x, y) in tqdm(enumerate(test_dataloader), total=len(test_dataloader), leave=False): + x, y = x.cuda(), y.cuda() + output = model(x) + test_loss += F.cross_entropy(output, y, reduction='sum').item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(y.view_as(pred)).sum().item() + test_sum += y.size(0) + + test_loss /= test_sum + logger.info( + f"Finish Test Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {test_loss:.3f} Accuracy: [{correct}/{test_sum}]({correct/test_sum:.3f})", + ranks=[0]) + + lr_scheduler.step() + + +if __name__ == '__main__': + train_imagenet() diff --git a/examples/images/vit/vit.py b/examples/images/vit/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..1116c7416934f5152f358fc2d6539f249874f05f --- /dev/null +++ b/examples/images/vit/vit.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +from utils.dummy_data_generator import DummyDataGenerator + +from colossalai.utils.cuda import get_current_device +from transformers import ViTConfig, ViTForImageClassification + + +class DummyDataLoader(DummyDataGenerator): + batch_size = 4 + channel = 3 + category = 8 + image_size = 224 + + def generate(self): + image_dict = {} + image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size, + DummyDataLoader.channel, + DummyDataLoader.image_size, + DummyDataLoader.image_size, + device=get_current_device()) * 2 - 1 + image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,), + dtype=torch.int64, + device=get_current_device()) + return image_dict + + +class ViTCVModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + image_size=224, + patch_size=16, + num_channels=3, + num_labels=8, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = ViTForImageClassification( + ViTConfig(hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + image_size=image_size, + patch_size=patch_size, + num_channels=num_channels, + num_labels=num_labels)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, pixel_values): + return self.model(pixel_values=pixel_values) + + +def vit_base_s(checkpoint=True): + return ViTCVModel(checkpoint=checkpoint) + + +def vit_base_micro(checkpoint=True): + return ViTCVModel(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, checkpoint=checkpoint) + + +def get_training_components(): + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + return vit_base_micro, trainloader, testloader, torch.optim.Adam, torch.nn.functional.cross_entropy diff --git a/examples/language/gpt/README.md b/examples/language/gpt/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b6b0ddc141fbfbc8facd8d086a34a998ebbfbf6a --- /dev/null +++ b/examples/language/gpt/README.md @@ -0,0 +1,52 @@ +# Train GPT with Colossal-AI + +This example shows how to use [Colossal-AI](https://github.com/hpcaitech/ColossalAI) to run huggingface GPT training in distributed manners. + +## GPT + +We use the [GPT-2](https://huggingface.co/gpt2) model from huggingface transformers. The key learning goal of GPT-2 is to use unsupervised pre-training models to do supervised tasks.GPT-2 has an amazing performance in text generation, and the generated text exceeds people's expectations in terms of contextual coherence and emotional expression. + +## Requirements + +Before you can launch training, you need to install the following requirements. + +### Install PyTorch + +```bash +#conda +conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch +#pip +pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 +``` + +### Install [Colossal-AI v0.1.11rc5](https://colossalai.org/download/) From Official Website + +```bash +pip install colossalai==0.1.11rc5+torch1.12cu11.3 -f https://release.colossalai.org +``` + +### Install transformers + +```bash +pip install transformers +``` + +This is just an example that we download PyTorch=1.12.0, CUDA=11.6 and colossalai=0.1.11rc5+torch1.12cu11.3. You can download another version of PyTorch and its corresponding ColossalAI version. Just make sure that the version of ColossalAI is at least 0.1.10, PyTorch is at least 1.8.1 and transformers is at least 4.231. + +## Dataset + +For simplicity, the input data is randonly generated here. + +## Training + +```bash +bash run.sh +``` + +### Training config + +The `train_gpt_demo.py` provides three distributed plans, you can choose the plan you want in `run.sh`. The Colossal-AI leverages Tensor Parallel and Gemini + ZeRO DDP. + +- Colossal-AI +- PyTorch DDP +- ZeRO \ No newline at end of file diff --git a/examples/language/gpt/requirements.txt b/examples/language/gpt/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..208a31ebba320486b3f7fc359dc7f8045c81d972 --- /dev/null +++ b/examples/language/gpt/requirements.txt @@ -0,0 +1,3 @@ +colossalai >= 0.1.10 +torch >= 1.8.1 +transformers >= 4.231 diff --git a/examples/language/gpt/run.sh b/examples/language/gpt/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..6a4b5ce14b8acb6486f80624def5b49dd8ee6991 --- /dev/null +++ b/examples/language/gpt/run.sh @@ -0,0 +1,10 @@ +# distplan in ["colossalai", "zero", "ddp"] +export DISTPAN="colossalai" + +# The following options only valid when DISTPAN="colossalai" +export TPDEGREE=2 +export GPUNUM=4 +export PLACEMENT='cpu' +export USE_SHARD_INIT=False + +env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log diff --git a/examples/language/gpt/train_gpt_demo.py b/examples/language/gpt/train_gpt_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..92123e6a716f3826b9757dfd91b6af507dfbe278 --- /dev/null +++ b/examples/language/gpt/train_gpt_demo.py @@ -0,0 +1,285 @@ +from functools import partial +from time import time + +import psutil +import torch +import torch.nn as nn +from packaging import version +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer +from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.utils import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from transformers import GPT2Config, GPT2LMHeadModel + + +def parse_args(): + parser = colossalai.get_default_parser() + parser.add_argument( + "--distplan", + type=str, + default='colossalai', + help="The distributed plan [colossalai, ddp, zero].", + ) + parser.add_argument( + "--tp_degree", + type=int, + default=1, + help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--placement", + type=str, + default='cpu', + help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--shardinit", + type=bool, + default=False, + help= + "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", + ) + args = parser.parse_args() + return args + + +## Parameter Sharding Strategies for Tensor Parallelism +def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): + spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + param.set_tensor_spec(*spec) + + +def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(0, param, pg) + + +def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(-1, param, pg) + + +## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +## Randomly Generated Data +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def gpt2_medium(checkpoint=False): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_xl(checkpoint=True): + return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint) + + +def gpt2_10b(checkpoint=True): + return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint) + + +def get_cpu_mem(): + return psutil.Process().memory_info().rss / 1024**2 + + +def get_gpu_mem(): + return torch.cuda.memory_allocated() / 1024**2 + + +def get_mem_info(prefix=''): + return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) + + +# Tensor Parallel +def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): + """tensor_parallelize + Sharding the Model Parameters. + + Args: + model (torch.nn.Module): a torch module to be sharded + """ + for mn, module in model.named_modules(): + for pn, param in module.named_parameters(recurse=False): + # NOTE() a param maybe shared by tow modules + if hasattr(param, 'visited'): + continue + param.set_dist_spec(ReplicaSpec()) + if 'mlp.c_fc' in mn: + if 'weight' in pn or 'bias' in pn: + split_param_col_tp1d(param, pg) # colmn slice + # keep the shape of the output from c_fc + param.compute_spec.set_output_replicate(False) + else: + param.set_dist_spec(ReplicaSpec()) + elif 'mlp.c_proj' in mn: + if 'weight' in pn: + split_param_row_tp1d(param, pg) # row slice + else: + param.set_dist_spec(ReplicaSpec()) + elif 'wte' in mn or 'wpe' in mn: + split_param_col_tp1d(param, pg) # colmn slice + elif 'c_attn' in mn or 'c_proj' in mn: + split_param_col_tp1d(param, pg) # colmn slice + else: + param.set_dist_spec(ReplicaSpec()) + + param.visited = True + + +# Gemini + ZeRO DDP +def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"): + cai_version = colossalai.__version__ + if version.parse(cai_version) > version.parse("0.1.10"): + from colossalai.nn.parallel import GeminiDDP + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=placememt_policy, + pin_memory=True, + search_range_mb=32) + elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): + from colossalai.gemini import ChunkManager, GeminiManager + chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) + gemini_manager = GeminiManager(placememt_policy, chunk_manager) + chunk_manager = ChunkManager(chunk_size, + pg, + enable_distributed_storage=True, + init_device=GeminiManager.get_default_device(placememt_policy)) + model = ZeroDDP(model, gemini_manager) + else: + raise NotImplemented(f"CAI version {cai_version} is not supported") + return model + + +def main(): + args = parse_args() + + BATCH_SIZE = 8 + SEQ_LEN = 1024 + VOCAB_SIZE = 50257 + NUM_STEPS = 10 + + disable_existing_loggers() + colossalai.launch_from_torch(config={}) + + logger = get_dist_logger() + logger.info(f"using dist plan {args.distplan}", ranks=[0]) + + # build criterion + criterion = GPTLMLoss() + + torch.manual_seed(123) + if args.distplan == "colossalai": + # all param must use the same process group. + default_pg = ProcessGroup(tp_degree=args.tp_degree) + default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None + + # build GPT model + with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg): + model = gpt2_medium(checkpoint=True) + + pg = default_pg + # Tensor Parallelism (TP) + tensor_parallelize(model, pg) + # Gemini + ZeRO DP, Note it must be used after TP + model = gemini_zero_dpp(model, pg, args.placement) + + # build optimizer + optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5) + # optimizer = HybridAdam(model.parameters(), lr=1e-3) + # optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5) + logger.info(get_mem_info(prefix='After init optim, '), ranks=[0]) + + elif args.distplan == "ddp": + model = gpt2_medium(checkpoint=True).cuda() + ddp_model = DDP(model) + optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01) + + elif args.distplan == "zero": + from torch.distributed.optim import ZeroRedundancyOptimizer + model = gpt2_medium(checkpoint=True).cuda() + ddp_model = DDP(model) + optimizer = ZeroRedundancyOptimizer(ddp_model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01) + else: + raise TypeError(f"{args.distplan} is error") + + numel = sum([p.numel() for p in model.parameters()]) + logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN) + + torch.cuda.synchronize() + model.train() + for n in range(NUM_STEPS): + # we just use randomly generated data here + input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE) + optimizer.zero_grad() + start = time() + outputs = model(input_ids, attn_mask) + loss = criterion(outputs, input_ids) + logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0]) + if args.distplan == "colossalai": + optimizer.backward(loss) + elif args.distplan in ["ddp", "zero"]: + loss.backward() + + logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0]) + optimizer.step() + logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0]) + step_time = time() - start + logger.info( + f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', + ranks=[0]) + + torch.cuda.synchronize() + + +if __name__ == '__main__': + main() diff --git a/examples/language/opt/README.md b/examples/language/opt/README.md new file mode 100644 index 0000000000000000000000000000000000000000..75573b70919a0e40975965bf9087d076fb6142eb --- /dev/null +++ b/examples/language/opt/README.md @@ -0,0 +1,52 @@ + + +## OPT +Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments. + +The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. + +We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before +the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). + +## Our Modifications +We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP. + +## Quick Start +You can launch training by using the following bash script + +```bash +bash ./run_clm.sh +``` + +- batch-size-per-gpu: number of samples fed to each GPU, default is 16 +- mem-cap: limit memory usage within a value in GB, default is 0 (no limit) +- model: the size of the OPT model, default is `6.7b`. Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7`, `13b`, `30b`, `66b`. For `175b`, you can request +the pretrained weights from [OPT weight downloading page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT). +- gpu-num: the number of GPUs to use, default is 1. + +## Remarkable Performance +On a single GPU, Colossal-AIโ€™s automatic strategy provides remarkable performance gains from the ZeRO Offloading strategy by Microsoft DeepSpeed. +Users can experience up to a 40% speedup, at a variety of model scales. However, when using a traditional deep learning training framework like PyTorch, a single GPU can no longer support the training of models at such a scale. + +

+ +

+ +Adopting the distributed training strategy with 8 GPUs is as simple as adding a `-nprocs 8` to the training command of Colossal-AI! + +More details about behind the scenes can be found on the corresponding [blog](https://medium.com/@yangyou_berkeley/colossal-ai-seamlessly-accelerates-large-models-at-low-costs-with-hugging-face-4d1a887e500d), +and a detailed tutorial will be added in [Documentation](https://www.colossalai.org/docs/get_started/installation) very soon. diff --git a/examples/language/opt/benchmark.sh b/examples/language/opt/benchmark.sh new file mode 100644 index 0000000000000000000000000000000000000000..f02f7629ad16f42fc04f131c26869477fc59aa90 --- /dev/null +++ b/examples/language/opt/benchmark.sh @@ -0,0 +1,21 @@ +export BS=16 +export MEMCAP=0 +export MODEL="6.7b" +export GPUNUM=1 + +for MODEL in "6.7b" "13b" "1.3b" +do +for GPUNUM in 8 1 +do +for BS in 16 24 32 8 +do +for MEMCAP in 0 40 +do +pkill -9 torchrun +pkill -9 python + +bash ./run_clm.sh $BS $MEMCAP $MODEL $GPUNUM +done +done +done +done diff --git a/examples/language/opt/colossalai_zero.py b/examples/language/opt/colossalai_zero.py new file mode 100644 index 0000000000000000000000000000000000000000..833745f3e8d84ef76305ebbecc09822746243be3 --- /dev/null +++ b/examples/language/opt/colossalai_zero.py @@ -0,0 +1,6 @@ +from colossalai.zero.shard_utils import TensorShardStrategy + +zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(), + tensor_placement_policy="auto", + reuse_fp16_shard=True), + optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384)) diff --git a/examples/language/opt/context.py b/examples/language/opt/context.py new file mode 100644 index 0000000000000000000000000000000000000000..95f0abf1d8c92ed5766e5f0fa2c70618be7827c5 --- /dev/null +++ b/examples/language/opt/context.py @@ -0,0 +1,32 @@ +import torch.distributed as dist + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + + +class barrier_context(): + """ + This context manager is used to allow one process to execute while blocking all + other processes in the same process group. This is often useful when downloading is required + as we only want to download in one process to prevent file corruption. + Args: + executor_rank (int): the process rank to execute without blocking, all other processes will be blocked + parallel_mode (ParallelMode): the parallel mode corresponding to a process group + Usage: + with barrier_context(): + dataset = CIFAR10(root='./data', download=True) + """ + + def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL): + # the class name is lowercase by convention + current_rank = gpc.get_local_rank(parallel_mode=parallel_mode) + self.should_block = current_rank != executor_rank + self.group = gpc.get_group(parallel_mode=parallel_mode) + + def __enter__(self): + if self.should_block: + dist.barrier(group=self.group) + + def __exit__(self, exc_type, exc_value, exc_traceback): + if not self.should_block: + dist.barrier(group=self.group) diff --git a/examples/language/opt/requirements.txt b/examples/language/opt/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c34df7992d3f1b6fa19cd58ac36c6a9b3499c75f --- /dev/null +++ b/examples/language/opt/requirements.txt @@ -0,0 +1,6 @@ +colossalai +torch >= 1.8.1 +datasets >= 1.8.0 +sentencepiece != 0.1.92 +protobuf +accelerate == 0.13.2 diff --git a/examples/language/opt/run_clm.py b/examples/language/opt/run_clm.py new file mode 100644 index 0000000000000000000000000000000000000000..c6590323e3a47fdf121c1406798322f0e5151977 --- /dev/null +++ b/examples/language/opt/run_clm.py @@ -0,0 +1,596 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) +on a text file or a dataset without using HuggingFace Trainer. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. + +import math +import os +import time +from itertools import chain + +import datasets +import torch +import torch.distributed as dist +from accelerate.utils import set_seed +from context import barrier_context +from datasets import load_dataset +from packaging import version +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + +import colossalai +import transformers +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer +from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ProcessGroup +from colossalai.utils import get_current_device, get_dataloader +from colossalai.utils.model.colo_init_context import ColoInitContext +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoTokenizer, + GPT2Tokenizer, + OPTForCausalLM, + SchedulerType, + default_data_collator, + get_scheduler, +) +from transformers.utils.versions import require_version + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def get_time_stamp(): + torch.cuda.synchronize() + return time.time() + + +def parse_args(): + parser = colossalai.get_default_parser() + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument("--train_file", + type=str, + default=None, + help="A csv or a json file containing the training data.") + parser.add_argument("--validation_file", + type=str, + default=None, + help="A csv or a json file containing the validation data.") + parser.add_argument( + "--validation_split_percentage", + default=5, + help="The percentage of the train set used as validation set in case there's no validation split", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=True, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the ๐Ÿค— Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument("--num_warmup_steps", + type=int, + default=0, + help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--block_size", + type=int, + default=None, + help=("Optional input sequence length after tokenization. The training dataset will be truncated in block of" + " this size for training. Default to the model max input length for single sentence inputs (take into" + " account special tokens)."), + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=None, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument("--overwrite_cache", + type=bool, + default=False, + help="Overwrite the cached training and evaluation sets") + parser.add_argument("--no_keep_linebreaks", + action="store_true", + help="Do not keep line breaks when using TXT files.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_model_id", + type=str, + help="The name of the repository to keep in sync with the local `output_dir`.") + parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--checkpointing_steps", + type=str, + default=None, + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--with_tracking", + action="store_true", + help="Whether to enable experiment trackers for logging.", + ) + parser.add_argument( + "--report_to", + type=str, + default="all", + help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed."), + ) + + parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap") + parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu") + args = parser.parse_args() + + # Sanity checks + if args.dataset_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + return args + + +def colo_memory_cap(size_in_GB): + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + cuda_capacity = colo_device_memory_capacity(get_current_device()) + if size_in_GB * (1024**3) < cuda_capacity: + colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) + print("Using {} GB of GPU memory".format(size_in_GB)) + + +def main(): + args = parse_args() + disable_existing_loggers() + colossalai.launch_from_torch(config=dict()) + logger = get_dist_logger() + is_main_process = dist.get_rank() == 0 + + if is_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + if args.mem_cap > 0: + colo_memory_cap(args.mem_cap) + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}") + + # Handle the repository creation + with barrier_context(): + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + logger.info("Start preparing dataset", ranks=[0]) + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[:{args.validation_split_percentage}%]", + ) + raw_datasets["train"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[{args.validation_split_percentage}%:]", + ) + else: + data_files = {} + dataset_args = {} + if args.train_file is not None: + data_files["train"] = args.train_file + if args.validation_file is not None: + data_files["validation"] = args.validation_file + extension = args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks + raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args) + # If no validation data is there, validation_split_percentage will be used to divide the dataset. + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{args.validation_split_percentage}%]", + **dataset_args, + ) + raw_datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{args.validation_split_percentage}%:]", + **dataset_args, + ) + logger.info("Dataset is prepared", ranks=[0]) + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + if args.config_name: + config = AutoConfig.from_pretrained(args.config_name) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained(args.model_name_or_path) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + logger.info("Model config has been created", ranks=[0]) + + if args.model_name_or_path == 'facebook/opt-13b': + tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path) + else: + print(f'load model from {args.model_name_or_path}') + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) + logger.info(f"{tokenizer.__class__.__name__} has been created", ranks=[0]) + + if args.init_in_cpu: + init_dev = torch.device('cpu') + else: + init_dev = get_current_device() + + # build model + if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b': + # currently, there has a bug in pretrained opt-13b + # we can not import it until huggingface fix it + logger.info("Train a new model from scratch", ranks=[0]) + with ColoInitContext(device=init_dev): + model = OPTForCausalLM(config) + else: + logger.info("Finetune a pre-trained model", ranks=[0]) + with ColoInitContext(device=init_dev): + model = OPTForCausalLM.from_pretrained(args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + local_files_only=False) + + # enable graident checkpointing + model.gradient_checkpointing_enable() + + PLACEMENT_POLICY = 'auto' + cai_version = colossalai.__version__ + logger.info(f'using Colossal-AI version {cai_version}') + if version.parse(cai_version) > version.parse("0.1.10"): + from colossalai.nn.parallel import GeminiDDP + model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) + elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): + from colossalai.gemini import ChunkManager, GeminiManager + pg = ProcessGroup() + chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) + chunk_manager = ChunkManager(chunk_size, + pg, + enable_distributed_storage=True, + init_device=GeminiManager.get_default_device(PLACEMENT_POLICY)) + gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager) + model = ZeroDDP(model, gemini_manager) + + logger.info(f'{model.__class__.__name__} has been created', ranks=[0]) + + # Preprocessing the datasets. + # First we tokenize all the texts. + column_names = raw_datasets["train"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + def tokenize_function(examples): + return tokenizer(examples[text_column_name]) + + with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA): + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + + if args.block_size is None: + block_size = tokenizer.model_max_length + if block_size > 1024: + logger.warning( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + "Picking 1024 instead. You can change that default value by passing --block_size xxx.") + block_size = 1024 + else: + if args.block_size > tokenizer.model_max_length: + logger.warning(f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.") + block_size = min(args.block_size, tokenizer.model_max_length) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i:i + block_size] for i in range(0, total_length, block_size) + ] for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder + # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower + # to preprocess. + # + # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map + + with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA): + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=args.preprocessing_num_workers, + load_from_cache_file=not args.overwrite_cache, + desc=f"Grouping texts in chunks of {block_size}", + ) + + train_dataset = lm_datasets["train"] + eval_dataset = lm_datasets["validation"] + + # Log a few random samples from the training set: + # for index in random.sample(range(len(train_dataset)), 3): + # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # DataLoaders creation: + train_dataloader = get_dataloader(train_dataset, + shuffle=True, + add_sampler=True, + collate_fn=default_data_collator, + batch_size=args.per_device_train_batch_size) + eval_dataloader = DataLoader(eval_dataset, + collate_fn=default_data_collator, + batch_size=args.per_device_eval_batch_size) + logger.info("Dataloaders have been created", ranks=[0]) + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate) + optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Train! + total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA) + + logger.info("***** Running training *****", ranks=[0]) + logger.info(f" Num examples = {len(train_dataset)}", ranks=[0]) + logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0]) + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0]) + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0]) + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0]) + logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0]) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process) + completed_steps = 0 + starting_epoch = 0 + global_step = 0 + + for epoch in range(starting_epoch, args.num_train_epochs): + + if completed_steps >= args.max_train_steps: + break + + model.train() + for step, batch in enumerate(train_dataloader): + batch = {k: v.cuda() for k, v in batch.items()} + outputs = model(**batch) + loss = outputs['loss'] + optimizer.backward(loss) + + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + completed_steps += 1 + + global_step += 1 + logger.info("Global step {} finished".format(global_step + 1), ranks=[0]) + + if completed_steps >= args.max_train_steps: + break + + model.eval() + losses = [] + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + batch = {k: v.cuda() for k, v in batch.items()} + outputs = model(**batch) + + loss = outputs['loss'].unsqueeze(0) + losses.append(loss) + + losses = torch.cat(losses) + losses = losses[:len(eval_dataset)] + try: + eval_loss = torch.mean(losses) + perplexity = math.exp(eval_loss) + except OverflowError: + perplexity = float("inf") + + logger.info(f"Epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0]) + + if args.output_dir is not None: + model_state = model.state_dict() + if is_main_process: + torch.save(model_state, args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) + dist.barrier() + # load_state = torch.load(args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) + # model.load_state_dict(load_state, strict=False) + + logger.info("Training finished", ranks=[0]) + + +if __name__ == "__main__": + main() diff --git a/examples/language/opt/run_clm.sh b/examples/language/opt/run_clm.sh new file mode 100644 index 0000000000000000000000000000000000000000..858d3325a7b4f9f55592ac4a7d62836f2a0a0501 --- /dev/null +++ b/examples/language/opt/run_clm.sh @@ -0,0 +1,22 @@ +set -x +export BS=${1:-16} +export MEMCAP=${2:-0} +export MODEL=${3:-"125m"} +export GPUNUM=${4:-1} + +# make directory for logs +mkdir -p ./logs + +export MODLE_PATH="facebook/opt-${MODEL}" + +# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 +torchrun \ + --nproc_per_node ${GPUNUM} \ + --master_port 19198 \ + run_clm.py \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --output_dir $PWD \ + --mem_cap ${MEMCAP} \ + --model_name_or_path ${MODLE_PATH} \ + --per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log diff --git a/examples/language/roberta/README.md b/examples/language/roberta/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c119d23b5824e43c92e0ed2d76dac8c97895d778 --- /dev/null +++ b/examples/language/roberta/README.md @@ -0,0 +1,58 @@ +# Introduction +This repo introduce how to pretrain a chinese roberta-large from scratch, including preprocessing, pretraining, finetune. The repo can help you quickly train a high-quality bert. + +## 0. Prerequisite +- Install Colossal-AI +- Editing the port from /etc/ssh/sshd_config and /etc/ssh/ssh_config, every host expose the same ssh port of server and client. If you are a root user, you also set the **PermitRootLogin** from /etc/ssh/sshd_config to "yes" +- Ensure that each host can log in to each other without password. If you have n hosts, need to execute n2 times + +``` +ssh-keygen +ssh-copy-id -i ~/.ssh/id_rsa.pub ip_destination +``` + +- In all hosts, edit /etc/hosts to record all hosts' name and ip.The example is shown below. + +```bash +192.168.2.1 GPU001 +192.168.2.2 GPU002 +192.168.2.3 GPU003 +192.168.2.4 GPU004 +192.168.2.5 GPU005 +192.168.2.6 GPU006 +192.168.2.7 GPU007 +... +``` + +- restart ssh +``` +service ssh restart +``` + +## 1. Corpus Preprocessing +```bash +cd preprocessing +``` +following the `README.md`, preprocess orginal corpus to h5py+numpy + +## 2. Pretrain + +```bash +cd pretraining +``` +following the `README.md`, load the h5py generated by preprocess of step 1 to pretrain the model + +## 3. Finetune + +The checkpoint produced by this repo can replace `pytorch_model.bin` from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) directly. Then use transfomers from HuggingFace to finetune downstream application. + +## Contributors +The repo is contributed by AI team from [Moore Threads](https://www.mthreads.com/). If you find any problems for pretraining, please file an issue or send an email to yehua.zhang@mthreads.com. At last, welcome any form of contribution! + +``` +@misc{ + title={A simple Chinese RoBERTa Example for Whole Word Masked}, + author={Yehua Zhang, Chen Zhang}, + year={2022} +} +``` \ No newline at end of file diff --git a/examples/language/roberta/configs/colossalai_ddp.py b/examples/language/roberta/configs/colossalai_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..c3c59aa4079c12ca638d08f7f234adae421f1ed4 --- /dev/null +++ b/examples/language/roberta/configs/colossalai_ddp.py @@ -0,0 +1,4 @@ +from colossalai.zero.shard_utils import TensorShardStrategy +from colossalai.nn.optimizer import FusedAdam + +clip_grad_norm = 1.0 diff --git a/examples/language/roberta/configs/colossalai_zero.py b/examples/language/roberta/configs/colossalai_zero.py new file mode 100644 index 0000000000000000000000000000000000000000..c5debdce0988110b2e3e858aeaec9a4958ed05d2 --- /dev/null +++ b/examples/language/roberta/configs/colossalai_zero.py @@ -0,0 +1,32 @@ +from colossalai.zero.shard_utils import TensorShardStrategy +from colossalai.nn.optimizer import FusedAdam + +# fp16 = dict( +# mode=AMP_TYPE.TORCH, +# ) + +# seed = 2 +zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(), + reduce_scatter_bucket_size_mb=25, + fp32_reduce_scatter=False, + tensor_placement_policy="cuda", + gradient_predivide_factor=1.0, + reuse_fp16_shard=False), + optimizer_config=dict(gpu_margin_mem_ratio=0.8, + initial_scale=2**5, + min_scale=1, + growth_factor=2, + backoff_factor=0.5, + growth_interval=1000, + hysteresis=2, + max_scale=2**32)) + +# gradient_accumulation = 4 +clip_grad_norm = 1.0 +optimizer = dict( + type=FusedAdam, + lr=0.00015, + weight_decay=1e-2, +) + +# 64433 \ No newline at end of file diff --git a/examples/language/roberta/preprocessing/Makefile b/examples/language/roberta/preprocessing/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..82ee4e1c5b31bf53cc859c970268eb5070aa107f --- /dev/null +++ b/examples/language/roberta/preprocessing/Makefile @@ -0,0 +1,9 @@ +CXXFLAGS += -O3 -Wall -shared -std=c++14 -fPIC -fdiagnostics-color +CPPFLAGS += $(shell python3 -m pybind11 --includes) +LIBNAME = mask +LIBEXT = $(shell python3-config --extension-suffix) + +default: $(LIBNAME)$(LIBEXT) + +%$(LIBEXT): %.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ diff --git a/examples/language/roberta/preprocessing/README.md b/examples/language/roberta/preprocessing/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1dbd745ab9bd528756b501b042cf95cf4a476c9c --- /dev/null +++ b/examples/language/roberta/preprocessing/README.md @@ -0,0 +1,105 @@ +# Data PreProcessing for chinese Whole Word Masked + + + +## Catalogue: +* 1. Introduction +* 2. Quick Start Guide: + * 2.1. Split Sentence + * 2.2.Tokenizer & Whole Word Masked + + + + +## 1. Introduction: [Back to Top] +This folder is used to preprocess chinese corpus with Whole Word Masked. You can obtain corpus from [WuDao](https://resource.wudaoai.cn/home?ind&name=WuDaoCorpora%202.0&id=1394901288847716352). Moreover, data preprocessing is flexible, and you can modify the code based on your needs, hardware or parallel framework(Open MPI, Spark, Dask). + + + +## 2. Quick Start Guide: [Back to Top] + + + +### 2.1. Split Sentence & Split data into multiple shard: +Firstly, each file has multiple documents, and each document contains multiple sentences. Split sentence through punctuation, such as `ใ€‚๏ผ`. **Secondly, split data into multiple shard based on server hardware (cpu, cpu memory, hard disk) and corpus size.** Each shard contains a part of corpus, and the model needs to train all the shards as one epoch. +In this example, split 200G Corpus into 100 shard, and each shard is about 2G. The size of the shard is memory-dependent, taking into account the number of servers, the memory used by the tokenizer, and the memory used by the multi-process training to read the shard (n data parallel requires n\*shard_size memory). **To sum up, data preprocessing and model pretraining requires fighting with hardware, not just GPU.** + +```python +python sentence_split.py --input_path /orginal_corpus --output_path /shard --shard 100 +# This step takes a short time +``` +* `--input_path`: all original corpus, e.g., /orginal_corpus/0.json /orginal_corpus/1.json ... +* `--output_path`: all shard with split sentences, e.g., /shard/0.txt, /shard/1.txt ... +* `--shard`: Number of shard, e.g., 10, 50, or 100 + +Input json: + +``` +[ + { + "id": 0, + "title": "ๆ‰“็ฏฎ็ƒ", + "content": "ๆˆ‘ไปŠๅคฉๅŽปๆ‰“็ฏฎ็ƒใ€‚ไธๅ›žๆฅๅƒ้ฅญใ€‚" + } + { + "id": 1, + "title": "ๆ—…ๆธธ", + "content": "ๆˆ‘ๅŽๅคฉๅŽปๆ—…ๆธธใ€‚ไธ‹ๅ‘จ่ฏทๅ‡ใ€‚" + } +] +``` + +Output txt: + +``` +ๆˆ‘ไปŠๅคฉๅŽปๆ‰“็ฏฎ็ƒใ€‚ +ไธๅ›žๆฅๅƒ้ฅญใ€‚ +]] +ๆˆ‘ๅŽๅคฉๅŽปๆ—…ๆธธใ€‚ +ไธ‹ๅ‘จ่ฏทๅ‡ใ€‚ +``` + + + +### 2.2. Tokenizer & Whole Word Masked: + +```python +python tokenize_mask.py --input_path /shard --output_path /h5 --tokenizer_path /roberta --backend python +# This step is time consuming and is mainly spent on mask +``` + +**[optional but recommended]**: the C++ backend with `pybind11` can provide faster speed + +```shell +make +``` + +* `--input_path`: location of all shard with split sentences, e.g., /shard/0.txt, /shard/1.txt ... +* `--output_path`: location of all h5 with token_id, input_mask, segment_ids and masked_lm_positions, e.g., /h5/0.h5, /h5/1.h5 ... +* `--tokenizer_path`: tokenizer path contains huggingface tokenizer.json. Download config.json, special_tokens_map.json, vocab.txt and tokenzier.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) +* `--backend`: python or c++, **specifies c++ can obtain faster preprocess speed** +* `--dupe_factor`: specifies how many times the preprocessor repeats to create the input from the same article/document +* `--worker`: number of process + +Input txt: + +``` +ๆˆ‘ไปŠๅคฉๅŽปๆ‰“็ฏฎ็ƒใ€‚ +ไธๅ›žๆฅๅƒ้ฅญใ€‚ +]] +ๆˆ‘ๅŽๅคฉๅŽปๆ—…ๆธธใ€‚ +ไธ‹ๅ‘จ่ฏทๅ‡ใ€‚ +``` + +Output h5+numpy: + +``` +'input_ids': [[id0,id1,id2,id3,id4,id5,id6,0,0..], + ...] +'input_mask': [[1,1,1,1,1,1,0,0..], + ...] +'segment_ids': [[0,0,0,0,0,...], + ...] +'masked_lm_positions': [[label1,-1,-1,label2,-1...], + ...] +``` \ No newline at end of file diff --git a/examples/language/roberta/preprocessing/get_mask.py b/examples/language/roberta/preprocessing/get_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..da297f98e6c924a081dbf286ed52d7ae0e9b8ca8 --- /dev/null +++ b/examples/language/roberta/preprocessing/get_mask.py @@ -0,0 +1,266 @@ +import torch +import os +from enum import IntEnum +from random import choice +import random +import collections +import time +import logging +import jieba +jieba.setLogLevel(logging.CRITICAL) +import re +import numpy as np +import mask + +PAD = 0 +MaskedLMInstance = collections.namedtuple("MaskedLMInstance", + ["index", "label"]) + + +def map_to_numpy(data): + return np.asarray(data) + + +class PreTrainingDataset(): + def __init__(self, + tokenizer, + max_seq_length, + backend='python', + max_predictions_per_seq: int = 80, + do_whole_word_mask: bool = True): + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + self.masked_lm_prob = 0.15 + self.backend = backend + self.do_whole_word_mask = do_whole_word_mask + self.max_predictions_per_seq = max_predictions_per_seq + self.vocab_words = list(tokenizer.vocab.keys()) + self.rec = re.compile('[\u4E00-\u9FA5]') + self.whole_rec = re.compile('##[\u4E00-\u9FA5]') + + self.mlm_p = 0.15 + self.mlm_mask_p = 0.8 + self.mlm_tamper_p = 0.05 + self.mlm_maintain_p = 0.1 + + + def tokenize(self, doc): + temp = [] + for d in doc: + temp.append(self.tokenizer.tokenize(d)) + return temp + + + def create_training_instance(self, instance): + is_next = 1 + raw_text_list = self.get_new_segment(instance) + tokens_a = raw_text_list + assert len(tokens_a) == len(instance) + # tokens_a, tokens_b, is_next = instance.get_values() + # print(f'is_next label:{is_next}') + # Create mapper + tokens = [] + original_tokens = [] + segment_ids = [] + tokens.append("[CLS]") + original_tokens.append('[CLS]') + segment_ids.append(0) + for index, token in enumerate(tokens_a): + tokens.append(token) + original_tokens.append(instance[index]) + segment_ids.append(0) + + tokens.append("[SEP]") + original_tokens.append('[SEP]') + segment_ids.append(0) + + # for token in tokens_b: + # tokens.append(token) + # segment_ids.append(1) + + # tokens.append("[SEP]") + # segment_ids.append(1) + + # Get Masked LM predictions + if self.backend == 'c++': + output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(tokens, original_tokens, self.vocab_words, + self.tokenizer.vocab, self.max_predictions_per_seq, self.masked_lm_prob) + elif self.backend == 'python': + output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens) + + # Convert to Ids + input_ids = self.tokenizer.convert_tokens_to_ids(output_tokens) + input_mask = [1] * len(input_ids) + + while len(input_ids) < self.max_seq_length: + input_ids.append(PAD) + segment_ids.append(PAD) + input_mask.append(PAD) + masked_lm_output.append(-1) + return ([ + map_to_numpy(input_ids), + map_to_numpy(input_mask), + map_to_numpy(segment_ids), + map_to_numpy(masked_lm_output), + map_to_numpy([is_next]) + ]) + + + def create_masked_lm_predictions(self, tokens): + cand_indexes = [] + for i, token in enumerate(tokens): + if token == "[CLS]" or token == "[SEP]": + continue + if (self.do_whole_word_mask and len(cand_indexes) >= 1 and + token.startswith("##")): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + + # cand_indexes.append(i) + + random.shuffle(cand_indexes) + output_tokens = list(tokens) + + num_to_predict = min( + self.max_predictions_per_seq, + max(1, int(round(len(tokens) * self.masked_lm_prob)))) + + masked_lms = [] + covered_indexes = set() + for index in cand_indexes: + if len(masked_lms) >= num_to_predict: + break + if index in covered_indexes: + continue + covered_indexes.add(index) + + masked_token = None + # 80% mask + if random.random() < 0.8: + masked_token = "[MASK]" + else: + # 10% Keep Original + if random.random() < 0.5: + masked_token = tokens[index] + # 10% replace w/ random word + else: + masked_token = self.vocab_words[random.randint( + 0, + len(self.vocab_words) - 1)] + + output_tokens[index] = masked_token + masked_lms.append( + MaskedLMInstance(index=index, label=tokens[index])) + + masked_lms = sorted(masked_lms, key=lambda x: x.index) + masked_lm_output = [-1] * len(output_tokens) + for p in masked_lms: + masked_lm_output[p.index] = self.tokenizer.vocab[p.label] + + return (output_tokens, masked_lm_output) + + + def get_new_segment(self, segment): + """ + ่พ“ๅ…ฅไธ€ๅฅ่ฏ๏ผŒ่ฟ”ๅ›žไธ€ๅฅ็ป่ฟ‡ๅค„็†็š„่ฏ: ไธบไบ†ๆ”ฏๆŒไธญๆ–‡ๅ…จ็งฐmask๏ผŒๅฐ†่ขซๅˆ†ๅผ€็š„่ฏ๏ผŒๅฐ†ไธŠ็‰นๆฎŠๆ ‡่ฎฐ("#")๏ผŒไฝฟๅพ—ๅŽ็ปญๅค„็†ๆจกๅ—๏ผŒ่ƒฝๅคŸ็Ÿฅ้“ๅ“ชไบ›ๅญ—ๆ˜ฏๅฑžไบŽๅŒไธ€ไธช่ฏ็š„ใ€‚ + :param segment: ไธ€ๅฅ่ฏ + :return: ไธ€ๅฅๅค„็†่ฟ‡็š„่ฏ + """ + seq_cws = jieba.lcut(''.join(segment)) + seq_cws_dict = {x: 1 for x in seq_cws} + new_segment = [] + i = 0 + while i < len(segment): + if len(self.rec.findall(segment[i])) == 0: # ไธๆ˜ฏไธญๆ–‡็š„๏ผŒๅŽŸๆ–‡ๅŠ ่ฟ›ๅŽปใ€‚ + new_segment.append(segment[i]) + i += 1 + continue + + has_add = False + for length in range(3, 0, -1): + if i + length > len(segment): + continue + if ''.join(segment[i: i+length]) in seq_cws_dict: + new_segment.append(segment[i]) + for l in range(1, length): + new_segment.append('##' + segment[i+l]) + i += length + has_add = True + break + if not has_add: + new_segment.append(segment[i]) + i += 1 + return new_segment + + + def create_whole_masked_lm_predictions(self, tokens): + """Creates the predictions for the masked LM objective.""" + + cand_indexes = [] + for (i, token) in enumerate(tokens): + if token == "[CLS]" or token == "[SEP]": + continue + # Whole Word Masking means that if we mask all of the wordpieces + # corresponding to an original word. When a word has been split into + # WordPieces, the first token does not have any marker and any subsequence + # tokens are prefixed with ##. So whenever we see the ## token, we + # append it to the previous set of word indexes. + # + # Note that Whole Word Masking does *not* change the training code + # at all -- we still predict each WordPiece independently, softmaxed + # over the entire vocabulary. + if (self.do_whole_word_mask and len(cand_indexes) >= 1 and + token.startswith("##")): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + + random.shuffle(cand_indexes) + + output_tokens = [t[2:] if len(self.whole_rec.findall(t))>0 else t for t in tokens] # ๅŽปๆމ"##" + + num_to_predict = min(self.max_predictions_per_seq, + max(1, int(round(len(tokens) * self.masked_lm_prob)))) + + masked_lms = [] + covered_indexes = set() + for index_set in cand_indexes: + if len(masked_lms) >= num_to_predict: + break + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(masked_lms) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + + masked_token = None + # 80% of the time, replace with [MASK] + if random.random() < 0.8: + masked_token = "[MASK]" + else: + # 10% of the time, keep original + if random.random() < 0.5: + masked_token = tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index] # ๅŽปๆމ"##" + # 10% of the time, replace with random word + else: + masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)] + + output_tokens[index] = masked_token + + masked_lms.append(MaskedLMInstance(index=index, label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index])) + assert len(masked_lms) <= num_to_predict + masked_lms = sorted(masked_lms, key=lambda x: x.index) + masked_lm_output = [-1] * len(output_tokens) + for p in masked_lms: + masked_lm_output[p.index] = self.tokenizer.vocab[p.label] + + return (output_tokens, masked_lm_output) diff --git a/examples/language/roberta/preprocessing/mask.cpp b/examples/language/roberta/preprocessing/mask.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8355c45cff0af50d60d320360f5b27f3dba1aad7 --- /dev/null +++ b/examples/language/roberta/preprocessing/mask.cpp @@ -0,0 +1,184 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +const int32_t LONG_SENTENCE_LEN = 512; + +struct MaskedLMInstance { + int index; + std::string label; + MaskedLMInstance(int index, std::string label) { + this->index = index; + this->label = label; + } +}; + +auto get_new_segment(std::vector segment, std::vector segment_jieba, const std::vector chinese_vocab) { // const std::unordered_set &chinese_vocab + std::unordered_set seq_cws_dict; + for (auto word : segment_jieba) { + seq_cws_dict.insert(word); + } + int i = 0; + std::vector new_segment; + int segment_size = segment.size(); + while (i < segment_size) { + if (!chinese_vocab[i]) { //chinese_vocab.find(segment[i]) == chinese_vocab.end() + new_segment.emplace_back(segment[i]); + i += 1; + continue; + } + bool has_add = false; + for (int length = 3; length >= 1; length--) { + if (i + length > segment_size) { + continue; + } + std::string chinese_word = ""; + for (int j = i; j < i + length; j++) { + chinese_word += segment[j]; + } + if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) { + new_segment.emplace_back(segment[i]); + for (int j = i + 1; j < i + length; j++) { + new_segment.emplace_back("##" + segment[j]); + } + i += length; + has_add = true; + break; + } + } + if (!has_add) { + new_segment.emplace_back(segment[i]); + i += 1; + } + } + + return new_segment; +} + +bool startsWith(const std::string& s, const std::string& sub) { + return s.find(sub) == 0 ? true : false; +} + +auto create_whole_masked_lm_predictions(std::vector &tokens, + const std::vector &original_tokens, + const std::vector &vocab_words, + std::map &vocab, + const int max_predictions_per_seq, + const double masked_lm_prob) { + // for (auto item : vocab) { + // std::cout << "key=" << std::string(py::str(item.first)) << ", " + // << "value=" << std::string(py::str(item.second)) << std::endl; + // } + std::vector > cand_indexes; + std::vector cand_temp; + int tokens_size = tokens.size(); + std::string prefix = "##"; + bool do_whole_masked = true; + + for (int i = 0; i < tokens_size; i++) { + if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") { + continue; + } + if (do_whole_masked && (cand_indexes.size() > 0) && (tokens[i].rfind(prefix, 0) == 0)) { + cand_temp.emplace_back(i); + } + else { + if (cand_temp.size() > 0) { + cand_indexes.emplace_back(cand_temp); + } + cand_temp.clear(); + cand_temp.emplace_back(i); + } + } + auto seed = std::chrono::system_clock::now().time_since_epoch().count(); + std::shuffle(cand_indexes.begin(), cand_indexes.end(), std::default_random_engine(seed)); + // for (auto i : cand_indexes) { + // for (auto j : i) { + // std::cout << tokens[j] << " "; + // } + // std::cout << std::endl; + // } + // for (auto i : output_tokens) { + // std::cout << i; + // } + // std::cout << std::endl; + + int num_to_predict = std::min(max_predictions_per_seq, + std::max(1, int(tokens_size * masked_lm_prob))); + // std::cout << num_to_predict << std::endl; + + std::set covered_indexes; + std::vector masked_lm_output(tokens_size, -1); + int vocab_words_len = vocab_words.size(); + std::default_random_engine e(seed); + std::uniform_real_distribution u1(0.0, 1.0); + std::uniform_int_distribution u2(0, vocab_words_len - 1); + int mask_cnt = 0; + std::vector output_tokens; + output_tokens = original_tokens; + + for (auto index_set : cand_indexes) { + if (mask_cnt > num_to_predict) { + break; + } + int index_set_size = index_set.size(); + if (mask_cnt + index_set_size > num_to_predict) { + continue; + } + bool is_any_index_covered = false; + for (auto index : index_set) { + if (covered_indexes.find(index) != covered_indexes.end()) { + is_any_index_covered = true; + break; + } + } + if (is_any_index_covered) { + continue; + } + for (auto index : index_set) { + + covered_indexes.insert(index); + std::string masked_token; + if (u1(e) < 0.8) { + masked_token = "[MASK]"; + } + else { + if (u1(e) < 0.5) { + masked_token = output_tokens[index]; + } + else { + int random_index = u2(e); + masked_token = vocab_words[random_index]; + } + } + // masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index])); + masked_lm_output[index] = vocab[output_tokens[index]]; + output_tokens[index] = masked_token; + mask_cnt++; + } + } + + // for (auto p : masked_lms) { + // masked_lm_output[p.index] = vocab[p.label]; + // } + return std::make_tuple(output_tokens, masked_lm_output); +} + +PYBIND11_MODULE(mask, m) { + m.def("create_whole_masked_lm_predictions", &create_whole_masked_lm_predictions); + m.def("get_new_segment", &get_new_segment); +} diff --git a/examples/language/roberta/preprocessing/sentence_split.py b/examples/language/roberta/preprocessing/sentence_split.py new file mode 100644 index 0000000000000000000000000000000000000000..231be152b067bd6cd15fb7a663cfb1abe8a4dc75 --- /dev/null +++ b/examples/language/roberta/preprocessing/sentence_split.py @@ -0,0 +1,163 @@ + +import multiprocessing +import os +import re +from tqdm import tqdm +from typing import List +import json +import time +import argparse +import functools + +def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]: + """ + Args: + document: + flag: Type:str, "all" ไธญ่‹ฑๆ–‡ๆ ‡็‚นๅˆ†ๅฅ๏ผŒ"zh" ไธญๆ–‡ๆ ‡็‚นๅˆ†ๅฅ๏ผŒ"en" ่‹ฑๆ–‡ๆ ‡็‚นๅˆ†ๅฅ + limit: ้ป˜่ฎคๅ•ๅฅๆœ€ๅคง้•ฟๅบฆไธบ510ไธชๅญ—็ฌฆ + Returns: Type:list + """ + sent_list = [] + try: + if flag == "zh": + document = re.sub('(?P([ใ€‚๏ผŸ๏ผโ€ฆ](?![โ€โ€™"\'])))', r'\g\n', document) # ๅ•ๅญ—็ฌฆๆ–ญๅฅ็ฌฆ + document = re.sub('(?P([ใ€‚๏ผŸ๏ผ]|โ€ฆ{1,2})[โ€โ€™"\'])', r'\g\n', document) # ็‰นๆฎŠๅผ•ๅท + elif flag == "en": + document = re.sub('(?P([.?!](?![โ€โ€™"\'])))', r'\g\n', document) # ่‹ฑๆ–‡ๅ•ๅญ—็ฌฆๆ–ญๅฅ็ฌฆ + document = re.sub('(?P([?!.]["\']))', r'\g\n', document) # ็‰นๆฎŠๅผ•ๅท + else: + document = re.sub('(?P([ใ€‚๏ผŸ๏ผโ€ฆ.?!](?![โ€โ€™"\'])))', r'\g\n', document) # ๅ•ๅญ—็ฌฆๆ–ญๅฅ็ฌฆ + + document = re.sub('(?P(([ใ€‚๏ผŸ๏ผ.!?]|โ€ฆ{1,2})[โ€โ€™"\']))', r'\g\n', + document) # ็‰นๆฎŠๅผ•ๅท + + sent_list_ori = document.splitlines() + for sent in sent_list_ori: + sent = sent.strip() + if not sent: + continue + elif len(sent) <= 2: + continue + else: + while len(sent) > limit: + temp = sent[0:limit] + sent_list.append(temp) + sent = sent[limit:] + sent_list.append(sent) + except: + sent_list.clear() + sent_list.append(document) + return sent_list + + +def get_sent(output_path, + input_path, + fin_list=[], host=-1, seq_len=512) -> None: + + workers = 32 + + if input_path[-1] == '/': + input_path = input_path[:-1] + + cur_path = os.path.join(output_path, str(host) + '.txt') + new_split_sentence = functools.partial(split_sentence, limit=seq_len-2) + with open(cur_path, 'w', encoding='utf-8') as f: + for fi, fin_path in enumerate(fin_list): + if not os.path.exists(os.path.join(input_path, fin_path[0])): + continue + if '.json' not in fin_path[0]: + continue + + print("Processing ", fin_path[0], " ", fi) + + with open(os.path.join(input_path, fin_path[0]), 'r') as fin: + f_data = [l['content'] for l in json.load(fin)] + + pool = multiprocessing.Pool(workers) + all_sent = pool.imap_unordered(new_split_sentence, f_data, 32) + pool.close() + print('finished..') + + cnt = 0 + for d in tqdm(all_sent): + for i in d: + f.write(i.strip() + '\n') + f.write(']]' + '\n') + cnt += 1 + # if cnt >= 2: + # exit() + + +def getFileSize(filepath, shard): + all_data = [] + for i in os.listdir(filepath): + all_data.append(os.path.join(filepath, i)) + all_size = sum([os.path.getsize(os.path.join(filepath, f)) for f in all_data]) + ans = [[f.split('/')[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data] + ans = sorted(ans, key=lambda x: x[1], reverse=True) + per_size = all_size / shard + real_shard = [] + temp = [] + accu_size = 0 + for i in ans: + accu_size += i[1] + temp.append(i) + if accu_size > per_size: + real_shard.append(temp) + accu_size = 0 + temp = [] + + if len(temp) > 0: + real_shard.append(temp) + + return real_shard + + +def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'): + import socket + host = int(socket.gethostname().split(server_name)[-1]) + + fin_list = real_shard[server_num * base + host - 1] + print(fin_list) + print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}') + return fin_list, host + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--server_num', type=int, default=10, help='number of servers') + parser.add_argument('--seq_len', type=int, default=512, help='sequence length') + parser.add_argument('--shard', type=int, default=100, help='number of shards, e.g., 10, 50, or 100') + parser.add_argument('--input_path', type=str, required=True, help='input path of original corpus') + parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence') + args = parser.parse_args() + + server_num = args.server_num + seq_len = args.seq_len + shard = args.shard + input_path = args.input_path + output_path = args.output_path + + real_shard = getFileSize(input_path, shard) + + start = time.time() + for index, shard in enumerate(real_shard): + get_sent(output_path, + input_path, + fin_list=shard, + host=index, + seq_len=seq_len) + print(f'cost {str(time.time() - start)}') + + # if you have multiple server, you can use code below or modify code to openmpi + + # for i in range(len(real_shard) // server_num + 1): + # fin_list, host = get_start_end(real_shard, i) + + # start = time.time() + # get_sent(output_path, + # input_path, + # fin_list=fin_list, host= 10 * i + host - 1) + + # print(f'cost {str(time.time() - start)}') diff --git a/examples/language/roberta/preprocessing/tokenize_mask.py b/examples/language/roberta/preprocessing/tokenize_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..b33871d5d0376cafa152f2a50050eada7ed89f78 --- /dev/null +++ b/examples/language/roberta/preprocessing/tokenize_mask.py @@ -0,0 +1,275 @@ +import time +import os +import psutil +import h5py +import socket +import argparse +import numpy as np +import multiprocessing +from tqdm import tqdm +from random import shuffle +from transformers import AutoTokenizer +from get_mask import PreTrainingDataset + + +def get_raw_instance(document, max_sequence_length=512): + + """ + ่Žทๅ–ๅˆๆญฅ็š„่ฎญ็ปƒๅฎžไพ‹๏ผŒๅฐ†ๆ•ดๆฎตๆŒ‰็…งmax_sequence_lengthๅˆ‡ๅˆ†ๆˆๅคšไธช้ƒจๅˆ†,ๅนถไปฅๅคšไธชๅค„็†ๅฅฝ็š„ๅฎžไพ‹็š„ๅฝขๅผ่ฟ”ๅ›žใ€‚ + :param document: ไธ€ๆ•ดๆฎต + :param max_sequence_length: + :return: a list. each element is a sequence of text + """ + # document = self.documents[index] + max_sequence_length_allowed = max_sequence_length - 2 + # document = [seq for seq in document if len(seq)= max_sequence_length_allowed: + if len(curr_seq) > 0: + result_list.append(curr_seq) + curr_seq = [] + result_list.append(document[sz_idx][ : max_sequence_length_allowed]) + sz_idx += 1 + else: + result_list.append(curr_seq) + curr_seq = [] + # ๅฏนๆœ€ๅŽไธ€ไธชๅบๅˆ—่ฟ›่กŒๅค„็†๏ผŒๅฆ‚ๆžœๅคช็Ÿญ็š„่ฏ๏ผŒไธขๅผƒๆމใ€‚ + if len(curr_seq) > max_sequence_length_allowed / 2: # /2 + result_list.append(curr_seq) + + # # ่ฎก็ฎ—ๆ€ปๅ…ฑๅฏไปฅๅพ—ๅˆฐๅคšๅฐ‘ไปฝ + # num_instance=int(len(big_list)/max_sequence_length_allowed)+1 + # print("num_instance:",num_instance) + # # ๅˆ‡ๅˆ†ๆˆๅคšไปฝ๏ผŒๆทปๅŠ ๅˆฐๅˆ—่กจไธญ + # result_list=[] + # for j in range(num_instance): + # index=j*max_sequence_length_allowed + # end_index=index+max_sequence_length_allowed if j!=num_instance-1 else -1 + # result_list.append(big_list[index:end_index]) + return result_list + + +def split_numpy_chunk(path, tokenizer, pretrain_data, host): + + documents = [] + instances = [] + + s = time.time() + with open(path, encoding='utf-8') as fd: + document = [] + for i, line in enumerate(tqdm(fd)): + line = line.strip() + # document = line + # if len(document.split("")) <= 3: + # continue + if len(line + ) > 0 and line[:2] == "]]": # This is end of document + documents.append(document) + document = [] + elif len(line) >= 2: + document.append(line) + if len(document) > 0: + documents.append(document) + print('read_file ', time.time() - s) + + # documents = [x for x in documents if x] + # print(len(documents)) + # print(len(documents[0])) + # print(documents[0][0:10]) + from typing import List + import multiprocessing + + ans = [] + for docs in tqdm(documents): + ans.append(pretrain_data.tokenize(docs)) + print(time.time() - s) + del documents + + instances = [] + for a in tqdm(ans): + raw_ins = get_raw_instance(a) + instances.extend(raw_ins) + del ans + + print('len instance', len(instances)) + + sen_num = len(instances) + seq_len = 512 + input_ids = np.zeros([sen_num, seq_len], dtype=np.int32) + input_mask = np.zeros([sen_num, seq_len], dtype=np.int32) + segment_ids = np.zeros([sen_num, seq_len], dtype=np.int32) + masked_lm_output = np.zeros([sen_num, seq_len], dtype=np.int32) + + for index, ins in tqdm(enumerate(instances)): + mask_dict = pretrain_data.create_training_instance(ins) + input_ids[index] = mask_dict[0] + input_mask[index] = mask_dict[1] + segment_ids[index] = mask_dict[2] + masked_lm_output[index] = mask_dict[3] + + with h5py.File(f'/output/{host}.h5', 'w') as hf: + hf.create_dataset("input_ids", data=input_ids) + hf.create_dataset("input_mask", data=input_ids) + hf.create_dataset("segment_ids", data=segment_ids) + hf.create_dataset("masked_lm_positions", data=masked_lm_output) + + del instances + + +def split_numpy_chunk_pool(input_path, + output_path, + pretrain_data, + worker, + dupe_factor, + seq_len, + file_name): + + if os.path.exists(os.path.join(output_path, f'{file_name}.h5')): + print(f'{file_name}.h5 exists') + return + + documents = [] + instances = [] + + s = time.time() + with open(input_path, 'r', encoding='utf-8') as fd: + document = [] + for i, line in enumerate(tqdm(fd)): + line = line.strip() + if len(line + ) > 0 and line[:2] == "]]": # This is end of document + documents.append(document) + document = [] + elif len(line) >= 2: + document.append(line) + if len(document) > 0: + documents.append(document) + print(f'read_file cost {time.time() - s}, length is {len(documents)}') + + ans = [] + s = time.time() + pool = multiprocessing.Pool(worker) + encoded_doc = pool.imap_unordered(pretrain_data.tokenize, documents, 100) + for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour='cyan'): + ans.append(res) + pool.close() + print((time.time() - s) / 60) + del documents + + instances = [] + for a in tqdm(ans, colour='MAGENTA'): + raw_ins = get_raw_instance(a, max_sequence_length=seq_len) + instances.extend(raw_ins) + del ans + + print('len instance', len(instances)) + + new_instances = [] + for _ in range(dupe_factor): + for ins in instances: + new_instances.append(ins) + + shuffle(new_instances) + instances = new_instances + print('after dupe_factor, len instance', len(instances)) + + sentence_num = len(instances) + input_ids = np.zeros([sentence_num, seq_len], dtype=np.int32) + input_mask = np.zeros([sentence_num, seq_len], dtype=np.int32) + segment_ids = np.zeros([sentence_num, seq_len], dtype=np.int32) + masked_lm_output = np.zeros([sentence_num, seq_len], dtype=np.int32) + + s = time.time() + pool = multiprocessing.Pool(worker) + encoded_docs = pool.imap_unordered(pretrain_data.create_training_instance, instances, 32) + for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour='blue'): + input_ids[index] = mask_dict[0] + input_mask[index] = mask_dict[1] + segment_ids[index] = mask_dict[2] + masked_lm_output[index] = mask_dict[3] + pool.close() + print((time.time() - s) / 60) + + with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf: + hf.create_dataset("input_ids", data=input_ids) + hf.create_dataset("input_mask", data=input_mask) + hf.create_dataset("segment_ids", data=segment_ids) + hf.create_dataset("masked_lm_positions", data=masked_lm_output) + + del instances + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer') + parser.add_argument('--seq_len', type=int, default=512, help='sequence length') + parser.add_argument('--max_predictions_per_seq', type=int, default=80, help='number of shards, e.g., 10, 50, or 100') + parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence') + parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id') + parser.add_argument('--backend', type=str, default='python', help='backend of mask token, python, c++, numpy respectively') + parser.add_argument('--dupe_factor', type=int, default=1, help='specifies how many times the preprocessor repeats to create the input from the same article/document') + parser.add_argument('--worker', type=int, default=32, help='number of process') + parser.add_argument('--server_num', type=int, default=10, help='number of servers') + args = parser.parse_args() + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + pretrain_data = PreTrainingDataset(tokenizer, + args.seq_len, + args.backend, + max_predictions_per_seq=args.max_predictions_per_seq) + + + data_len = len(os.listdir(args.input_path)) + + for i in range(data_len): + input_path = os.path.join(args.input_path, f'{i}.txt') + if os.path.exists(input_path): + start = time.time() + print(f'process {input_path}') + split_numpy_chunk_pool(input_path, + args.output_path, + pretrain_data, + args.worker, + args.dupe_factor, + args.seq_len, + i) + end_ = time.time() + print(u'memory๏ผš%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) ) + print(f'has cost {(end_ - start) / 60}') + print('-' * 100) + print('') + + # if you have multiple server, you can use code below or modify code to openmpi + + # host = int(socket.gethostname().split('GPU')[-1]) + # for i in range(data_len // args.server_num + 1): + # h = args.server_num * i + host - 1 + # input_path = os.path.join(args.input_path, f'{h}.txt') + # if os.path.exists(input_path): + # start = time.time() + # print(f'I am server {host}, process {input_path}') + # split_numpy_chunk_pool(input_path, + # args.output_path, + # pretrain_data, + # args.worker, + # args.dupe_factor, + # args.seq_len, + # h) + # end_ = time.time() + # print(u'memory๏ผš%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) ) + # print(f'has cost {(end_ - start) / 60}') + # print('-' * 100) + # print('') + + diff --git a/examples/language/roberta/pretraining/README.md b/examples/language/roberta/pretraining/README.md new file mode 100644 index 0000000000000000000000000000000000000000..055d6969654d271b618c9cba1366d5dc1a0ef908 --- /dev/null +++ b/examples/language/roberta/pretraining/README.md @@ -0,0 +1,24 @@ +# Pretraining +1. Pretraining roberta through running the script below. Detailed parameter descriptions can be found in the arguments.py. `data_path_prefix` is absolute path specifies output of preprocessing. **You have to modify the *hostfile* according to your cluster.** + +```bash +bash run_pretrain.sh +``` +* `--hostfile`: servers' host name from /etc/hosts +* `--include`: servers which will be used +* `--nproc_per_node`: number of process(GPU) from each server +* `--data_path_prefix`: absolute location of train data, e.g., /h5/0.h5 +* `--eval_data_path_prefix`: absolute location of eval data +* `--tokenizer_path`: tokenizer path contains huggingface tokenizer.json, e.g./tokenizer/tokenizer.json +* `--bert_config`: config.json which represent model +* `--mlm`: model type of backbone, bert or deberta_v2 + +2. if resume training from earylier checkpoint, run the script below. + +```shell +bash run_pretrain_resume.sh +``` +* `--resume_train`: whether to resume training +* `--load_pretrain_model`: absolute path which contains model checkpoint +* `--load_optimizer_lr`: absolute path which contains optimizer checkpoint + diff --git a/examples/language/roberta/pretraining/arguments.py b/examples/language/roberta/pretraining/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..3a9370e00b0c33724808efe9ff4cff8d521db050 --- /dev/null +++ b/examples/language/roberta/pretraining/arguments.py @@ -0,0 +1,152 @@ +import colossalai +from numpy import require + +__all__ = ['parse_args'] + + +def parse_args(): + parser = colossalai.get_default_parser() + + parser.add_argument( + '--lr', + type=float, + required=True, + help='initial learning rate') + parser.add_argument( + '--epoch', + type=int, + required=True, + help='number of epoch') + parser.add_argument( + '--data_path_prefix', + type=str, + required=True, + help="location of the train data corpus") + parser.add_argument( + '--eval_data_path_prefix', + type=str, + required=True, + help='location of the evaluation data corpus') + parser.add_argument( + '--tokenizer_path', + type=str, + required=True, + help='location of the tokenizer') + parser.add_argument( + '--max_seq_length', + type=int, + default=512, + help='sequence length') + parser.add_argument( + '--refresh_bucket_size', + type=int, + default=1, + help= + "This param makes sure that a certain task is repeated for this time steps to \ + optimise on the back propogation speed with APEX's DistributedDataParallel") + parser.add_argument( + "--max_predictions_per_seq", + "--max_pred", + default=80, + type=int, + help= + "The maximum number of masked tokens in a sequence to be predicted.") + parser.add_argument( + "--gradient_accumulation_steps", + default=1, + type=int, + help="accumulation_steps") + parser.add_argument( + "--train_micro_batch_size_per_gpu", + default=2, + type=int, + required=True, + help="train batch size") + parser.add_argument( + "--eval_micro_batch_size_per_gpu", + default=2, + type=int, + required=True, + help="eval batch size") + parser.add_argument( + "--num_workers", + default=8, + type=int, + help="") + parser.add_argument( + "--async_worker", + action='store_true', + help="") + parser.add_argument( + "--bert_config", + required=True, + type=str, + help="location of config.json") + parser.add_argument( + "--wandb", + action='store_true', + help="use wandb to watch model") + parser.add_argument( + "--wandb_project_name", + default='roberta', + help="wandb project name") + parser.add_argument( + "--log_interval", + default=100, + type=int, + help="report interval") + parser.add_argument( + "--log_path", + type=str, + required=True, + help="log file which records train step") + parser.add_argument( + "--tensorboard_path", + type=str, + required=True, + help="location of tensorboard file") + parser.add_argument( + "--colossal_config", + type=str, + required=True, + help="colossal config, which contains zero config and so on") + parser.add_argument( + "--ckpt_path", + type=str, + required=True, + help="location of saving checkpoint, which contains model and optimizer") + parser.add_argument( + '--seed', + type=int, + default=42, + help="random seed for initialization") + parser.add_argument( + '--vscode_debug', + action='store_true', + help="use vscode to debug") + parser.add_argument( + '--load_pretrain_model', + default='', + type=str, + help="location of model's checkpoin") + parser.add_argument( + '--load_optimizer_lr', + default='', + type=str, + help="location of checkpoint, which contains optimerzier, learning rate, epoch, shard and global_step") + parser.add_argument( + '--resume_train', + action='store_true', + help="whether resume training from a early checkpoint") + parser.add_argument( + '--mlm', + default='bert', + type=str, + help="model type, bert or deberta") + parser.add_argument( + '--checkpoint_activations', + action='store_true', + help="whether to use gradient checkpointing") + + args = parser.parse_args() + return args diff --git a/examples/language/roberta/pretraining/bert_dataset_provider.py b/examples/language/roberta/pretraining/bert_dataset_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..1d8cf2a910e9e41275edc480c25e86fa166fd0ee --- /dev/null +++ b/examples/language/roberta/pretraining/bert_dataset_provider.py @@ -0,0 +1,15 @@ +class BertDatasetProviderInterface: + def get_shard(self, index, shuffle=True): + raise NotImplementedError + + def release_shard(self, index): + raise NotImplementedError + + def prefetch_shard(self, index): + raise NotImplementedError + + def get_batch(self, batch_iter): + raise NotImplementedError + + def prefetch_batch(self): + raise NotImplementedError diff --git a/examples/language/roberta/pretraining/evaluation.py b/examples/language/roberta/pretraining/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..83f94082f6c0773e5de4a8c8daf14971e76d2426 --- /dev/null +++ b/examples/language/roberta/pretraining/evaluation.py @@ -0,0 +1,71 @@ +import os +import math +import torch +from tqdm import tqdm +from utils.global_vars import get_timers, get_tensorboard_writer +from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider + +def evaluate(engine, args, logger, global_step): + evaluate_dataset_provider = NvidiaBertDatasetProvider(args, evaluate=True) + start_shard = 0 + + engine.eval() + timers = get_timers() + eval_step = 0 + eval_loss = 0 + cur_loss = 0 + world_size = torch.distributed.get_world_size() + + with torch.no_grad(): + + for shard in range(start_shard, len(os.listdir(args.eval_data_path_prefix))): + + timers('eval_shard_time').start() + + dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard) + # evaluate_dataset_provider.prefetch_shard(shard + 1) + if torch.distributed.get_rank() == 0: + iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), colour='MAGENTA', smoothing=1) + else: + iterator_data = enumerate(dataset_iterator) + + for step, batch_data in iterator_data: #tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1): + + # batch_data = pretrain_dataset_provider.get_batch(batch_index) + eval_step += 1 + input_ids = batch_data[0].cuda() + attention_mask = batch_data[1].cuda() + token_type_ids = batch_data[2].cuda() + mlm_label = batch_data[3].cuda() + # nsp_label = batch_data[5].cuda() + + output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + loss = engine.criterion(output.logits, mlm_label)#prediction_scores + evaluate_dataset_provider.prefetch_batch() + + eval_loss += loss.float().item() + + cur_loss = eval_loss / eval_step + elapsed_time = timers("eval_shard_time").elapsed() + elapsed_time_per_iteration = elapsed_time / eval_step + ppl = math.exp(cur_loss) + + if args.wandb and torch.distributed.get_rank() == 0: + tensorboard_log = get_tensorboard_writer() + tensorboard_log.log_eval({ + 'loss': cur_loss, + 'ppl': ppl, + 'mins_batch': elapsed_time_per_iteration + }, global_step) + + eval_log_str = f'evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ + f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}' + + logger.info(eval_log_str) + logger.info('-' * 100) + logger.info('') + + evaluate_dataset_provider.release_shard() + engine.train() + return cur_loss diff --git a/examples/language/roberta/pretraining/hostfile b/examples/language/roberta/pretraining/hostfile new file mode 100644 index 0000000000000000000000000000000000000000..f4e047f01fdd6e2826e5cefd2bf84c8178d561ad --- /dev/null +++ b/examples/language/roberta/pretraining/hostfile @@ -0,0 +1,10 @@ +GPU001 +GPU002 +GPU003 +GPU004 +GPU005 +GPU006 +GPU007 +GPU008 +GPU009 +GPU010 diff --git a/examples/language/roberta/pretraining/loss.py b/examples/language/roberta/pretraining/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..dc4f872a755d233ad6a62c3effd4261ca45f809b --- /dev/null +++ b/examples/language/roberta/pretraining/loss.py @@ -0,0 +1,17 @@ +import torch + +__all__ = ['LossForPretraining'] + + +class LossForPretraining(torch.nn.Module): + + def __init__(self, vocab_size): + super(LossForPretraining, self).__init__() + self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1) + self.vocab_size = vocab_size + + def forward(self, prediction_scores, masked_lm_labels, next_sentence_labels=None): + masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1)) + # next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1)) + total_loss = masked_lm_loss #+ next_sentence_loss + return total_loss diff --git a/examples/language/roberta/pretraining/model/bert.py b/examples/language/roberta/pretraining/model/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..67c85f760776c49d18c7fc94bf65f22eb789de08 --- /dev/null +++ b/examples/language/roberta/pretraining/model/bert.py @@ -0,0 +1,1893 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + + +import math +import os +import warnings +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.bert.configuration_bert import BertConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "bert-base-uncased" +_CONFIG_FOR_DOC = "BertConfig" +_TOKENIZER_FOR_DOC = "BertTokenizer" + +# TokenClassification docstring +_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english" +_TOKEN_CLASS_EXPECTED_OUTPUT = ( + "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] " +) +_TOKEN_CLASS_EXPECTED_LOSS = 0.01 + +# QuestionAnswering docstring +_CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2" +_QA_EXPECTED_OUTPUT = "'a nice puppet'" +_QA_EXPECTED_LOSS = 7.41 +_QA_TARGET_START_INDEX = 14 +_QA_TARGET_END_INDEX = 15 + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity" +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'" +_SEQ_CLASS_EXPECTED_LOSS = 0.01 + + +BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "bert-base-uncased", + "bert-large-uncased", + "bert-base-cased", + "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese", + "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", + "bert-large-cased-whole-word-masking", + "bert-large-uncased-whole-word-masking-finetuned-squad", + "bert-large-cased-whole-word-masking-finetuned-squad", + "bert-base-cased-finetuned-mrpc", + "bert-base-german-dbmdz-cased", + "bert-base-german-dbmdz-uncased", + "cl-tohoku/bert-base-japanese", + "cl-tohoku/bert-base-japanese-whole-word-masking", + "cl-tohoku/bert-base-japanese-char", + "cl-tohoku/bert-base-japanese-char-whole-word-masking", + "TurkuNLP/bert-base-finnish-cased-v1", + "TurkuNLP/bert-base-finnish-uncased-v1", + "wietsedv/bert-base-dutch-cased", + # See all BERT models at https://huggingface.co/models?filter=bert +] + + +def load_tf_weights_in_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + if version.parse(torch.__version__) > version.parse("1.6.0"): + self.register_buffer( + "token_type_ids", + torch.zeros(self.position_ids.size(), dtype=torch.long), + persistent=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = BertAttention(config, position_embedding_type="absolute") + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = "bert" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + + +@dataclass +class BertForPreTrainingOutput(ModelOutput): + """ + Output type of [`BertForPreTraining`]. + + Args: + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: torch.FloatTensor = None + seq_relationship_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +BERT_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", + BERT_START_DOCSTRING, +) +class BertModel(BertPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next + sentence prediction (classification)` head. + """, + BERT_START_DOCSTRING, +) +class BertForPreTraining(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence + pair (see `input_ids` docstring) Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + kwargs (`Dict[str, any]`, optional, defaults to *{}*): + Used to hide legacy arguments that have been deprecated. + + Returns: + + Example: + + ```python + >>> from transformers import BertTokenizer, BertForPreTraining + >>> import torch + + >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + >>> model = BertForPreTraining.from_pretrained("bert-base-uncased") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING +) +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + +@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING) +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="'paris'", + expected_loss=0.88, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError("The PAD token should be defined for generation") + + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@add_start_docstrings( + """Bert Model with a `next sentence prediction (classification)` head on top.""", + BERT_START_DOCSTRING, +) +class BertForNextSentencePrediction(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertOnlyNSPHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + + Returns: + + Example: + + ```python + >>> from transformers import BertTokenizer, BertForNextSentencePrediction + >>> import torch + + >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ``` + """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use" + " `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + BERT_START_DOCSTRING, +) +class BertForSequenceClassification(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.bert = BertModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + BERT_START_DOCSTRING, +) +class BertForMultipleChoice(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + BERT_START_DOCSTRING, +) +class BertForTokenClassification(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT, + expected_loss=_TOKEN_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + BERT_START_DOCSTRING, +) +class BertForQuestionAnswering(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_QA, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + qa_target_start_index=_QA_TARGET_START_INDEX, + qa_target_end_index=_QA_TARGET_END_INDEX, + expected_output=_QA_EXPECTED_OUTPUT, + expected_loss=_QA_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/examples/language/roberta/pretraining/model/deberta_v2.py b/examples/language/roberta/pretraining/model/deberta_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..c6ce82847f75202d10f21493e37ee2021925ec22 --- /dev/null +++ b/examples/language/roberta/pretraining/model/deberta_v2.py @@ -0,0 +1,1631 @@ +# coding=utf-8 +# Copyright 2020 Microsoft and the Hugging Face Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeBERTa-v2 model.""" + +import math +from collections.abc import Sequence +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import softmax_backward_data +from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from transformers.models.deberta_v2.configuration_deberta_v2 import DebertaV2Config +from transformers import T5Tokenizer, T5ForConditionalGeneration, FillMaskPipeline + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DebertaV2Config" +_TOKENIZER_FOR_DOC = "DebertaV2Tokenizer" +_CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge" + +DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/deberta-v2-xlarge", + "microsoft/deberta-v2-xxlarge", + "microsoft/deberta-v2-xlarge-mnli", + "microsoft/deberta-v2-xxlarge-mnli", +] + + +# Copied from transformers.models.deberta.modeling_deberta.ContextPooler +class ContextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) + self.dropout = StableDropout(config.pooler_dropout) + self.config = config + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + + context_token = hidden_states[:, 0] + context_token = self.dropout(context_token) + pooled_output = self.dense(context_token) + pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output) + return pooled_output + + @property + def output_dim(self): + return self.config.hidden_size + + +# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2 +class XSoftmax(torch.autograd.Function): + """ + Masked Softmax which is optimized for saving memory + + Args: + input (`torch.tensor`): The input tensor that will apply softmax. + mask (`torch.IntTensor`): + The mask matrix where 0 indicate that element will be ignored in the softmax calculation. + dim (int): The dimension that will apply softmax + + Example: + + ```python + >>> import torch + >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax + + >>> # Make a tensor + >>> x = torch.randn([4, 20, 100]) + + >>> # Create a mask + >>> mask = (x > 0).int() + + >>> # Specify the dimension to apply softmax + >>> dim = -1 + + >>> y = XSoftmax.apply(x, mask, dim) + ```""" + + @staticmethod + def forward(self, input, mask, dim): + self.dim = dim + rmask = ~(mask.to(torch.bool)) + + output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min)) + output = torch.softmax(output, self.dim) + output.masked_fill_(rmask, 0) + self.save_for_backward(output) + return output + + @staticmethod + def backward(self, grad_output): + (output,) = self.saved_tensors + inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output) + return inputGrad, None, None + + @staticmethod + def symbolic(g, self, mask, dim): + import torch.onnx.symbolic_helper as sym_help + from torch.onnx.symbolic_opset9 import masked_fill, softmax + + mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"]) + r_mask = g.op( + "Cast", + g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), + to_i=sym_help.cast_pytorch_to_onnx["Byte"], + ) + output = masked_fill( + g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min)) + ) + output = softmax(g, output, dim) + return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8))) + + +# Copied from transformers.models.deberta.modeling_deberta.DropoutContext +class DropoutContext(object): + def __init__(self): + self.dropout = 0 + self.mask = None + self.scale = 1 + self.reuse_mask = True + + +# Copied from transformers.models.deberta.modeling_deberta.get_mask +def get_mask(input, local_context): + if not isinstance(local_context, DropoutContext): + dropout = local_context + mask = None + else: + dropout = local_context.dropout + dropout *= local_context.scale + mask = local_context.mask if local_context.reuse_mask else None + + if dropout > 0 and mask is None: + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool) + + if isinstance(local_context, DropoutContext): + if local_context.mask is None: + local_context.mask = mask + + return mask, dropout + + +# Copied from transformers.models.deberta.modeling_deberta.XDropout +class XDropout(torch.autograd.Function): + """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" + + @staticmethod + def forward(ctx, input, local_ctx): + mask, dropout = get_mask(input, local_ctx) + ctx.scale = 1.0 / (1 - dropout) + if dropout > 0: + ctx.save_for_backward(mask) + return input.masked_fill(mask, 0) * ctx.scale + else: + return input + + @staticmethod + def backward(ctx, grad_output): + if ctx.scale > 1: + (mask,) = ctx.saved_tensors + return grad_output.masked_fill(mask, 0) * ctx.scale, None + else: + return grad_output, None + + +# Copied from transformers.models.deberta.modeling_deberta.StableDropout +class StableDropout(nn.Module): + """ + Optimized dropout module for stabilizing the training + + Args: + drop_prob (float): the dropout probabilities + """ + + def __init__(self, drop_prob): + super().__init__() + self.drop_prob = drop_prob + self.count = 0 + self.context_stack = None + + def forward(self, x): + """ + Call the module + + Args: + x (`torch.tensor`): The input tensor to apply dropout + """ + if self.training and self.drop_prob > 0: + return XDropout.apply(x, self.get_context()) + return x + + def clear_context(self): + self.count = 0 + self.context_stack = None + + def init_context(self, reuse_mask=True, scale=1): + if self.context_stack is None: + self.context_stack = [] + self.count = 0 + for c in self.context_stack: + c.reuse_mask = reuse_mask + c.scale = scale + + def get_context(self): + if self.context_stack is not None: + if self.count >= len(self.context_stack): + self.context_stack.append(DropoutContext()) + ctx = self.context_stack[self.count] + ctx.dropout = self.drop_prob + self.count += 1 + return ctx + else: + return self.drop_prob + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm +class DebertaV2SelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2 +class DebertaV2Attention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = DisentangledSelfAttention(config) + self.output = DebertaV2SelfOutput(config) + self.config = config + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + self_output = self.self( + hidden_states, + attention_mask, + output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + self_output, att_matrix = self_output + if query_states is None: + query_states = hidden_states + attention_output = self.output(self_output, query_states) + + if output_attentions: + return (attention_output, att_matrix) + else: + return attention_output + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2 +class DebertaV2Intermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm +class DebertaV2Output(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 +class DebertaV2Layer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = DebertaV2Attention(config) + self.intermediate = DebertaV2Intermediate(config) + self.output = DebertaV2Output(config) + + def forward( + self, + hidden_states, + attention_mask, + query_states=None, + relative_pos=None, + rel_embeddings=None, + output_attentions=False, + ): + attention_output = self.attention( + hidden_states, + attention_mask, + output_attentions=output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + attention_output, att_matrix = attention_output + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + if output_attentions: + return (layer_output, att_matrix) + else: + return layer_output + + +class ConvLayer(nn.Module): + def __init__(self, config): + super().__init__() + kernel_size = getattr(config, "conv_kernel_size", 3) + groups = getattr(config, "conv_groups", 1) + self.conv_act = getattr(config, "conv_act", "tanh") + self.conv = nn.Conv1d( + config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups + ) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, residual_states, input_mask): + out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous() + rmask = (1 - input_mask).bool() + out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) + out = ACT2FN[self.conv_act](self.dropout(out)) + + layer_norm_input = residual_states + out + output = self.LayerNorm(layer_norm_input).to(layer_norm_input) + + if input_mask is None: + output_states = output + else: + if input_mask.dim() != layer_norm_input.dim(): + if input_mask.dim() == 4: + input_mask = input_mask.squeeze(1).squeeze(1) + input_mask = input_mask.unsqueeze(2) + + input_mask = input_mask.to(output.dtype) + output_states = output * input_mask + + return output_states + + +class DebertaV2Encoder(nn.Module): + """Modified BertEncoder with relative position bias support""" + + def __init__(self, config): + super().__init__() + + self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)]) + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + + self.position_buckets = getattr(config, "position_buckets", -1) + pos_ebd_size = self.max_relative_positions * 2 + + if self.position_buckets > 0: + pos_ebd_size = self.position_buckets * 2 + + # rel = nn.Parameter(torch.empty((pos_ebd_size, config.hidden_size))) + # self.rel_embeddings = nn.init.normal_(rel, mean=0.0, std=config.initializer_range) + self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size) + + self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] + + if "layer_norm" in self.norm_rel_ebd: + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) + + self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None + self.gradient_checkpointing = False + + def get_rel_embedding(self): + att_span = self.position_buckets + rel_index = torch.arange(0, att_span * 2).long().to(self.rel_embeddings.weight.device) + rel_embeddings = self.rel_embeddings(rel_index) + # rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + # rel_embeddings = self.rel_embeddings if self.relative_attention else None + if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd): + rel_embeddings = self.LayerNorm(rel_embeddings) + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if attention_mask.dim() <= 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) + attention_mask = attention_mask.byte() + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) + relative_pos = build_relative_position( + q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) + return relative_pos + + def forward( + self, + hidden_states, + attention_mask, + output_hidden_states=True, + output_attentions=False, + query_states=None, + relative_pos=None, + return_dict=True, + ): + if attention_mask.dim() <= 2: + input_mask = attention_mask + else: + input_mask = (attention_mask.sum(-2) > 0).byte() + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[0] + else: + next_kv = hidden_states + rel_embeddings = self.get_rel_embedding() + output_states = next_kv + for i, layer_module in enumerate(self.layer): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + output_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + next_kv, + attention_mask, + query_states, + relative_pos, + rel_embeddings, + ) + else: + output_states = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) + + if output_attentions: + output_states, att_m = output_states + + if i == 0 and self.conv is not None: + output_states = self.conv(hidden_states, output_states, input_mask) + + if query_states is not None: + query_states = output_states + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None + else: + next_kv = output_states + + if output_attentions: + all_attentions = all_attentions + (att_m,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + if not return_dict: + return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +def make_log_bucket_position(relative_pos, bucket_size, max_position): + sign = np.sign(relative_pos) + mid = bucket_size // 2 + abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos)) + log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid + bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int) + return bucket_pos + + +def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1): + """ + Build relative position according to the query and key + + We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key + \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q - + P_k\\) + + Args: + query_size (int): the length of query + key_size (int): the length of key + bucket_size (int): the size of position bucket + max_position (int): the maximum allowed absolute position + + Return: + `torch.LongTensor`: A tensor with shape [1, query_size, key_size] + + """ + q_ids = np.arange(0, query_size) + k_ids = np.arange(0, key_size) + rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1)) + if bucket_size > 0 and max_position > 0: + rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) + rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long) + rel_pos_ids = rel_pos_ids[:query_size, :] + rel_pos_ids = rel_pos_ids.unsqueeze(0) + return rel_pos_ids + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand +def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand +def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand +def pos_dynamic_expand(pos_index, p2c_att, key_layer): + return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))) + + +class DisentangledSelfAttention(nn.Module): + """ + Disentangled self-attention module + + Parameters: + config (`DebertaV2Config`): + A model config class instance with the configuration to build a new model. The schema is similar to + *BertConfig*, for more details, please refer [`DebertaV2Config`] + + """ + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + self.num_attention_heads = config.num_attention_heads + _attention_head_size = config.hidden_size // config.num_attention_heads + self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + + self.share_att_key = getattr(config, "share_att_key", False) + self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.position_buckets = getattr(config, "position_buckets", -1) + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.pos_ebd_size = self.max_relative_positions + if self.position_buckets > 0: + self.pos_ebd_size = self.position_buckets + + self.pos_dropout = StableDropout(config.hidden_dropout_prob) + + if not self.share_att_key: + if "c2p" in self.pos_att_type: + self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + if "p2c" in self.pos_att_type: + self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = StableDropout(config.attention_probs_dropout_prob) + # self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) + + def transpose_for_scores(self, x, attention_heads): + new_x_shape = x.size()[:-1] + (attention_heads, -1) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1)) + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + """ + Call the module + + Args: + hidden_states (`torch.FloatTensor`): + Input states to the module usually the output from previous layer, it will be the Q,K and V in + *Attention(Q,K,V)* + + attention_mask (`torch.ByteTensor`): + An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum + sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j* + th token. + + output_attentions (`bool`, optional): + Whether return the attention matrix. + + query_states (`torch.FloatTensor`, optional): + The *Q* state in *Attention(Q,K,V)*. + + relative_pos (`torch.LongTensor`): + The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with + values ranging in [*-max_relative_positions*, *max_relative_positions*]. + + rel_embeddings (`torch.FloatTensor`): + The embedding of relative distances. It's a tensor of shape [\\(2 \\times + \\text{max_relative_positions}\\), *hidden_size*]. + + + """ + if query_states is None: + query_states = hidden_states + query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads) + key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads) + value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads) + + rel_att = None + # Take the dot product between "query" and "key" to get the raw attention scores. + scale_factor = 1 + if "c2p" in self.pos_att_type: + scale_factor += 1 + if "p2c" in self.pos_att_type: + scale_factor += 1 + scale = math.sqrt(query_layer.size(-1) * scale_factor) + attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale + if self.relative_attention: + rel_embeddings = self.pos_dropout(rel_embeddings) + rel_att = self.disentangled_attention_bias( + query_layer, key_layer, relative_pos, rel_embeddings, scale_factor + ) + + if rel_att is not None: + attention_scores = attention_scores + rel_att + attention_scores = attention_scores + attention_scores = attention_scores.view( + -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) + ) + + # bsz x height x length x dimension + attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = self.dropout(attention_probs) + context_layer = torch.bmm( + attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer + ) + context_layer = ( + context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) + .permute(0, 2, 1, 3) + .contiguous() + ) + new_context_layer_shape = context_layer.size()[:-2] + (-1,) + context_layer = context_layer.view(new_context_layer_shape) + if output_attentions: + return (context_layer, attention_probs) + else: + return context_layer + + def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + if relative_pos is None: + q = query_layer.size(-2) + relative_pos = build_relative_position( + q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) + if relative_pos.dim() == 2: + relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) + elif relative_pos.dim() == 3: + relative_pos = relative_pos.unsqueeze(1) + # bsz x height x query x key + elif relative_pos.dim() != 4: + raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") + + att_span = self.pos_ebd_size + relative_pos = relative_pos.long().to(query_layer.device) + + # rel_index = torch.arange(0, att_span * 2).long().to(query_layer.device) + # rel_embeddings = rel_embeddings(rel_index).unsqueeze(0) + rel_embeddings = rel_embeddings.unsqueeze(0) + # rel_embeddings = rel_embeddings.unsqueeze(0) + # rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0) + if self.share_att_key: + pos_query_layer = self.transpose_for_scores( + self.query_proj(rel_embeddings), self.num_attention_heads + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) + pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) + else: + if "c2p" in self.pos_att_type: + pos_key_layer = self.transpose_for_scores( + self.pos_key_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + if "p2c" in self.pos_att_type: + pos_query_layer = self.transpose_for_scores( + self.pos_query_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + + score = 0 + # content->position + if "c2p" in self.pos_att_type: + scale = math.sqrt(pos_key_layer.size(-1) * scale_factor) + c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) + c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) + c2p_att = torch.gather( + c2p_att, + dim=-1, + index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]), + ) + score += c2p_att / scale + + # position->content + if "p2c" in self.pos_att_type: + scale = math.sqrt(pos_query_layer.size(-1) * scale_factor) + if key_layer.size(-2) != query_layer.size(-2): + r_pos = build_relative_position( + key_layer.size(-2), + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + ).to(query_layer.device) + r_pos = r_pos.unsqueeze(0) + else: + r_pos = relative_pos + + p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) + p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2)) + p2c_att = torch.gather( + p2c_att, + dim=-1, + index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]), + ).transpose(-1, -2) + score += p2c_att / scale + + return score + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm +class DebertaV2Embeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + pad_token_id = getattr(config, "pad_token_id", 0) + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id) + + self.position_biased_input = getattr(config, "position_biased_input", True) + if not self.position_biased_input: + self.position_embeddings = None + else: + self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size) + + if config.type_vocab_size > 0: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size) + + if self.embedding_size != config.hidden_size: + self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.position_embeddings is not None: + position_embeddings = self.position_embeddings(position_ids.long()) + else: + position_embeddings = torch.zeros_like(inputs_embeds) + + embeddings = inputs_embeds + if self.position_biased_input: + embeddings += position_embeddings + if self.config.type_vocab_size > 0: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings += token_type_embeddings + + if self.embedding_size != self.config.hidden_size: + embeddings = self.embed_proj(embeddings) + + embeddings = self.LayerNorm(embeddings) + + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2 +class DebertaV2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DebertaV2Config + base_model_prefix = "deberta" + _keys_to_ignore_on_load_missing = ["position_ids"] + _keys_to_ignore_on_load_unexpected = ["position_embeddings"] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, DebertaV2Encoder): + module.gradient_checkpointing = value + + +DEBERTA_START_DOCSTRING = r""" + The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled + Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build + on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two + improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data. + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior.``` + + + Parameters: + config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DEBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`DebertaV2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.", + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 +class DebertaV2Model(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = DebertaV2Embeddings(config) + self.encoder = DebertaV2Encoder(config) + self.z_steps = 0 + self.config = config + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError("The prune function is not implemented in DeBERTa model.") + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + mask=attention_mask, + inputs_embeds=inputs_embeds, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + encoded_layers = encoder_outputs[1] + + if self.z_steps > 1: + hidden_states = encoded_layers[-2] + layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] + query_states = encoded_layers[-1] + rel_embeddings = self.encoder.get_rel_embedding() + attention_mask = self.encoder.get_attention_mask(attention_mask) + rel_pos = self.encoder.get_rel_pos(embedding_output) + for layer in layers[1:]: + query_states = layer( + hidden_states, + attention_mask, + output_attentions=False, + query_states=query_states, + relative_pos=rel_pos, + rel_embeddings=rel_embeddings, + ) + encoded_layers.append(query_states) + + sequence_output = encoded_layers[-1] + + if not return_dict: + return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states if output_hidden_states else None, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2 +class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.deberta = DebertaV2Model(config) + self.cls = DebertaV2OnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta +class DebertaV2PredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta +class DebertaV2LMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = DebertaV2PredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta +class DebertaV2OnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = DebertaV2LMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +@add_start_docstrings( + """ + DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2 +class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, "num_labels", 2) + self.num_labels = num_labels + + self.deberta = DebertaV2Model(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = nn.Linear(output_dim, num_labels) + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = StableDropout(drop_out) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + # regression task + loss_fn = nn.MSELoss() + logits = logits.view(-1).to(labels.dtype) + loss = loss_fn(logits, labels.view(-1)) + elif labels.dim() == 1 or labels.size(-1) == 1: + label_index = (labels >= 0).nonzero() + labels = labels.long() + if label_index.size(0) > 0: + labeled_logits = torch.gather( + logits, 0, label_index.expand(label_index.size(0), logits.size(1)) + ) + labels = torch.gather(labels, 0, label_index.view(-1)) + loss_fct = CrossEntropyLoss() + loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) + else: + loss = torch.tensor(0).to(logits) + else: + log_softmax = nn.LogSoftmax(-1) + loss = -((log_softmax(logits) * labels).sum(-1)).mean() + elif self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2 +class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering with Deberta->DebertaV2 +class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + DEBERTA_START_DOCSTRING, +) +class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, "num_labels", 2) + self.num_labels = num_labels + + self.deberta = DebertaV2Model(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = nn.Linear(output_dim, 1) + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = StableDropout(drop_out) + + self.init_weights() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.deberta( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py b/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..cce83691350553dc12039f8f6ede7ddb1e1daf81 --- /dev/null +++ b/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py @@ -0,0 +1,182 @@ +import os +import random +import h5py +import logging +import json +import time +from concurrent.futures import ProcessPoolExecutor + +import numpy as np + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.sampler import RandomSampler +from torch.utils.data.distributed import DistributedSampler + +from bert_dataset_provider import BertDatasetProviderInterface +import colossalai.utils as utils + +# Workaround because python functions are not picklable +class WorkerInitObj(object): + def __init__(self, seed): + self.seed = seed + + def __call__(self, id): + np.random.seed(seed=self.seed + id) + random.seed(self.seed + id) + + +def create_pretraining_dataset(input_file, max_predictions_per_seq, + num_workers, train_batch_size, worker_init, + data_sampler): + train_data = pretraining_dataset( + input_file=input_file, max_predictions_per_seq=max_predictions_per_seq) + train_dataloader = DataLoader(train_data, + sampler=data_sampler(train_data), + batch_size=train_batch_size, + num_workers=num_workers, + worker_init_fn=worker_init, + pin_memory=True + ) + return train_dataloader, len(train_data) + + +class pretraining_dataset(Dataset): + def __init__(self, input_file, max_predictions_per_seq): + self.input_file = input_file + self.max_predictions_per_seq = max_predictions_per_seq + f = h5py.File(input_file, "r") + keys = [ + 'input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions' + ] + self.inputs = [np.asarray(f[key][:]) for key in keys] + f.close() + + def __len__(self): + 'Denotes the total number of samples' + return len(self.inputs[0]) + + def __getitem__(self, index): + + [ + input_ids, input_mask, segment_ids, masked_lm_labels + ] = [ + torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else + torch.from_numpy(np.asarray(input[index].astype(np.int64))) + for indice, input in enumerate(self.inputs) + ] + + return [ + input_ids, input_mask, + segment_ids, masked_lm_labels + ] + + +class NvidiaBertDatasetProvider(BertDatasetProviderInterface): + def __init__(self, args, evaluate=False): + self.num_workers = args.num_workers + self.max_seq_length = args.max_seq_length + self.max_predictions_per_seq = args.max_predictions_per_seq + + self.gradient_accumulation_steps = args.gradient_accumulation_steps + if not evaluate: + self.train_micro_batch_size_per_gpu = args.train_micro_batch_size_per_gpu + else: + self.train_micro_batch_size_per_gpu = args.eval_micro_batch_size_per_gpu + self.logger = args.logger + + self.global_rank = dist.get_rank() + self.world_size = dist.get_world_size() + + # Initialize dataset files + if not evaluate: + self.dataset_files = [ + os.path.join(args.data_path_prefix, f) for f in os.listdir(args.data_path_prefix) if + os.path.isfile(os.path.join(args.data_path_prefix, f)) and 'h5' in f + ] + else: + self.dataset_files = [ + os.path.join(args.eval_data_path_prefix, f) for f in os.listdir(args.eval_data_path_prefix) if + os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and 'h5' in f + ] + + self.dataset_files.sort() + # random.shuffle(self.dataset_files) + self.num_files = len(self.dataset_files) + # self.data_sampler = RandomSampler + self.data_sampler = DistributedSampler + + self.worker_init = WorkerInitObj(args.seed + args.local_rank) + self.dataset_future = None + self.pool = ProcessPoolExecutor(1) + self.data_file = None + self.shuffle = True + + if self.global_rank == 0: + self.logger.info( + f"NvidiaBertDatasetProvider - Initialization: num_files = {self.num_files}" + ) + + def get_shard(self, index): + start = time.time() + if self.dataset_future is None: + self.data_file = self._get_shard_file(index) + self.train_dataloader, sample_count = create_pretraining_dataset( + input_file=self.data_file, + max_predictions_per_seq=self.max_predictions_per_seq, + num_workers=self.num_workers, + train_batch_size=self.train_micro_batch_size_per_gpu, + worker_init=self.worker_init, + data_sampler=self.data_sampler) + else: + self.train_dataloader, sample_count = self.dataset_future.result( + timeout=None) + + self.logger.info( + f"Data Loading Completed for Pretraining Data from {self.data_file} with {sample_count} samples took {time.time()-start:.2f}s." + ) + + return self.train_dataloader, sample_count + + def release_shard(self): + del self.train_dataloader + self.pool.shutdown() + + def prefetch_shard(self, index): + self.data_file = self._get_shard_file(index) + self.dataset_future = self.pool.submit( + create_pretraining_dataset, self.data_file, + self.max_predictions_per_seq, self.num_workers, + self.train_micro_batch_size_per_gpu, self.worker_init, + self.data_sampler) + + def get_batch(self, batch_iter): + return batch_iter + + def prefetch_batch(self): + pass + + def _get_shard_file(self, shard_index): + file_index = self._get_shard_file_index(shard_index, self.global_rank) + return self.dataset_files[file_index] + + def _get_shard_file_index(self, shard_index, global_rank): + # if dist.is_initialized() and self.world_size > self.num_files: + # remainder = self.world_size % self.num_files + # file_index = (shard_index * self.world_size) + global_rank + ( + # remainder * shard_index) + # else: + # file_index = shard_index * self.world_size + global_rank + + return shard_index % self.num_files + + def shuffle_dataset(self, epoch): + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.num_files, generator=g).tolist() + new_dataset = [self.dataset_files[i] for i in indices] + self.dataset_files = new_dataset + \ No newline at end of file diff --git a/examples/language/roberta/pretraining/pretrain_utils.py b/examples/language/roberta/pretraining/pretrain_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ba17b0f5ee09cc4a6fbcf1615d91e7ef289fd778 --- /dev/null +++ b/examples/language/roberta/pretraining/pretrain_utils.py @@ -0,0 +1,112 @@ +import transformers +import logging +from colossalai.nn.lr_scheduler import LinearWarmupLR +from transformers import get_linear_schedule_with_warmup +from transformers import BertForPreTraining, RobertaForMaskedLM, RobertaConfig +from transformers import GPT2Config, GPT2LMHeadModel +from transformers import AutoTokenizer, AutoModelForMaskedLM +from colossalai.nn.optimizer import FusedAdam +from torch.optim import AdamW +from colossalai.core import global_context as gpc +import torch +import os +import sys +sys.path.append(os.getcwd()) +from model.deberta_v2 import DebertaV2ForMaskedLM +from model.bert import BertForMaskedLM +import torch.nn as nn + +from collections import OrderedDict + +__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining'] + + +def get_new_state_dict(state_dict, start_index=13): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[start_index:] + new_state_dict[name] = v + return new_state_dict + + +class LMModel(nn.Module): + def __init__(self, model, config, args): + super().__init__() + + self.checkpoint = args.checkpoint_activations + self.config = config + self.model = model + if self.checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, token_type_ids=None, attention_mask=None): + # Only return lm_logits + return self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + +def get_model(args, logger): + + if args.mlm == 'bert': + config = transformers.BertConfig.from_json_file(args.bert_config) + model = BertForMaskedLM(config) + elif args.mlm == 'deberta_v2': + config = transformers.DebertaV2Config.from_json_file(args.bert_config) + model = DebertaV2ForMaskedLM(config) + else: + raise Exception("Invalid mlm!") + + if len(args.load_pretrain_model) > 0: + assert os.path.exists(args.load_pretrain_model) + # load_checkpoint(args.load_pretrain_model, model, strict=False) + m_state_dict = torch.load(args.load_pretrain_model, map_location=torch.device(f"cuda:{torch.cuda.current_device()}")) + # new_state_dict = get_new_state_dict(m_state_dict) + model.load_state_dict(m_state_dict, strict=True) # must insure that every process have identical parameters !!!!!!! + logger.info("load model success") + + numel = sum([p.numel() for p in model.parameters()]) + if args.checkpoint_activations: + model.gradient_checkpointing_enable() + # model = LMModel(model, config, args) + + return config, model, numel + + +def get_optimizer(model, lr): + param_optimizer = list(model.named_parameters()) + no_decay = ['bias', 'gamma', 'beta', 'LayerNorm'] + + # configure the weight decay for bert models + optimizer_grouped_parameters = [{ + 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], + 'weight_decay': 0.1 + }, { + 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], + 'weight_decay': 0.0 + }] + optimizer = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=[0.9, 0.95]) + return optimizer + + +def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1): + # warmup_steps = int(total_steps * warmup_ratio) + lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, last_epoch=last_epoch) + # lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps) + return lr_scheduler + + +def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step): + model_path = path + '_pytorch_model.bin' + optimizer_lr_path = path + '.op_lrs' + checkpoint = {} + checkpoint['optimizer'] = optimizer.state_dict() + checkpoint['lr_scheduler'] = lr_scheduler.state_dict() + checkpoint['epoch'] = epoch + checkpoint['shard'] = shard + checkpoint['global_step'] = global_step + model_state = model.state_dict() #each process must run model.state_dict() + if gpc.get_global_rank() == 0: + torch.save(checkpoint, optimizer_lr_path) + torch.save(model_state, model_path) + + + diff --git a/examples/language/roberta/pretraining/run_pretrain.sh b/examples/language/roberta/pretraining/run_pretrain.sh new file mode 100644 index 0000000000000000000000000000000000000000..144cd0ab96fd22771889b9eebd6eecca0f939205 --- /dev/null +++ b/examples/language/roberta/pretraining/run_pretrain.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env sh + +root_path=$PWD +PY_FILE_PATH="$root_path/run_pretraining.py" + +tensorboard_path="$root_path/tensorboard" +log_path="$root_path/exp_log" +ckpt_path="$root_path/ckpt" + +colossal_config="$root_path/../configs/colossalai_ddp.py" + +mkdir -p $tensorboard_path +mkdir -p $log_path +mkdir -p $ckpt_path + +export PYTHONPATH=$PWD + +env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \ + --include GPU002,GPU003,GPU004,GPU007 \ + --nproc_per_node=8 \ + $PY_FILE_PATH \ + --master_addr GPU007 \ + --master_port 20024 \ + --lr 2.0e-4 \ + --train_micro_batch_size_per_gpu 190 \ + --eval_micro_batch_size_per_gpu 20 \ + --epoch 15 \ + --data_path_prefix /h5 \ + --eval_data_path_prefix /eval_h5 \ + --tokenizer_path /roberta \ + --bert_config /roberta/config.json \ + --tensorboard_path $tensorboard_path \ + --log_path $log_path \ + --ckpt_path $ckpt_path \ + --colossal_config $colossal_config \ + --log_interval 50 \ + --mlm bert \ + --wandb \ + --checkpoint_activations \ + \ No newline at end of file diff --git a/examples/language/roberta/pretraining/run_pretrain_resume.sh b/examples/language/roberta/pretraining/run_pretrain_resume.sh new file mode 100644 index 0000000000000000000000000000000000000000..a0704cf7c517f6e784a4f80483bfb5db3b0c5c94 --- /dev/null +++ b/examples/language/roberta/pretraining/run_pretrain_resume.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env sh + +root_path=$PWD +PY_FILE_PATH="$root_path/run_pretraining.py" + +tensorboard_path="$root_path/tensorboard" +log_path="$root_path/exp_log" +ckpt_path="$root_path/ckpt" + +colossal_config="$root_path/../configs/colossalai_ddp.py" + +mkdir -p $tensorboard_path +mkdir -p $log_path +mkdir -p $ckpt_path + +export PYTHONPATH=$PWD + +env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \ + --include GPU002,GPU003,GPU004,GPU007 \ + --nproc_per_node=8 \ + $PY_FILE_PATH \ + --master_addr GPU007 \ + --master_port 20024 \ + --lr 2.0e-4 \ + --train_micro_batch_size_per_gpu 190 \ + --eval_micro_batch_size_per_gpu 20 \ + --epoch 15 \ + --data_path_prefix /h5 \ + --eval_data_path_prefix /eval_h5 \ + --tokenizer_path /roberta \ + --bert_config /roberta/config.json \ + --tensorboard_path $tensorboard_path \ + --log_path $log_path \ + --ckpt_path $ckpt_path \ + --colossal_config $colossal_config \ + --log_interval 50 \ + --mlm bert \ + --wandb \ + --checkpoint_activations \ + --resume_train \ + --load_pretrain_model /ckpt/1.pt \ + --load_optimizer_lr /ckpt/1.op_lrs \ + \ No newline at end of file diff --git a/examples/language/roberta/pretraining/run_pretraining.py b/examples/language/roberta/pretraining/run_pretraining.py new file mode 100644 index 0000000000000000000000000000000000000000..9840a122cbc4300d85c2542568aeaaa1ae29d07e --- /dev/null +++ b/examples/language/roberta/pretraining/run_pretraining.py @@ -0,0 +1,226 @@ +import colossalai +import math +import torch +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +import colossalai.nn as col_nn +from arguments import parse_args +from pretrain_utils import get_model, get_optimizer, get_lr_scheduler, save_ckpt +from utils.exp_util import get_tflops, get_mem_info, throughput_calculator, log_args +from utils.global_vars import set_global_variables, get_timers, get_tensorboard_writer +from utils.logger import Logger +from evaluation import evaluate +from loss import LossForPretraining + +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import TensorShardStrategy +from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.sharded_optim import ShardedOptimizerV2 +from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider +from tqdm import tqdm +import os +import time +from functools import partial + +from transformers import AutoTokenizer + +from colossalai.gemini import ChunkManager, GeminiManager +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.utils import get_current_device +from colossalai.nn.parallel import ZeroDDP +from colossalai.zero import ZeroOptimizer +from colossalai.tensor import ProcessGroup +from colossalai.nn.optimizer import HybridAdam + + +def main(): + + args = parse_args() + launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + + os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + + logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug) + + if args.vscode_debug: + colossalai.launch(config={}, + rank=args.rank, + world_size=args.world_size, + host=args.host, + port=args.port, + backend=args.backend) + args.local_rank = -1 + args.log_interval = 1 + else: + colossalai.launch_from_torch(args.colossal_config) #args.colossal_config + args.local_rank = int(os.environ["LOCAL_RANK"]) + logger.info(f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + + f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}') + + log_args(logger, args) + args.tokenizer = tokenizer + args.logger = logger + set_global_variables(launch_time, args.tensorboard_path) + + use_zero = hasattr(gpc.config, 'zero') + world_size = torch.distributed.get_world_size() + + # build model, optimizer and criterion + if use_zero: + shard_strategy = TensorShardStrategy() + with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, + shard_param=True): + + config, model, numel = get_model(args, logger) + # model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True) + else: + config, model, numel = get_model(args, logger) + logger.info("no_zero") + if torch.distributed.get_rank() == 0: + os.mkdir(os.path.join(args.ckpt_path, launch_time)) + + logger.info(f'Model numel: {numel}') + + get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) + steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader) + total_steps = steps_per_epoch * args.epoch + + # build optimizer and lr_scheduler + + start_epoch = 0 + start_shard = 0 + global_step = 0 + if args.resume_train: + assert os.path.exists(args.load_optimizer_lr) + o_l_state_dict = torch.load(args.load_optimizer_lr, map_location='cpu') + o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1 + optimizer = get_optimizer(model, lr=args.lr) + optimizer.load_state_dict(o_l_state_dict['optimizer']) + lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) #o_l_state_dict['lr_scheduler']['last_epoch'] + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}") + # if you want delete the above three code, have to move the model to gpu, because in optimizer.step() + lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler']) + + start_epoch = o_l_state_dict['epoch'] + start_shard = o_l_state_dict['shard'] + 1 + # global_step = o_l_state_dict['global_step'] + 1 + logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}') + else: + optimizer = get_optimizer(model, lr=args.lr) + lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) + + # optimizer = gpc.config.optimizer.pop('type')( + # model.parameters(), **gpc.config.optimizer) + # optimizer = ShardedOptimizerV2(model, optimizer, initial_scale=2**5) + criterion = LossForPretraining(config.vocab_size) + + # build dataloader + pretrain_dataset_provider = NvidiaBertDatasetProvider(args) + + # initialize with colossalai + engine, _, _, lr_scheduelr = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + lr_scheduler=lr_scheduler) + + logger.info(get_mem_info(prefix='After init model, ')) + + + best_loss = None + eval_loss = 0 + train_loss = 0 + timers = get_timers() + timers('interval_time').start() + timers('epoch_time').start() + timers('shard_time').start() + + for epoch in range(start_epoch, args.epoch): + + for shard in range(start_shard, len(os.listdir(args.data_path_prefix))): + + dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard) + # pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload + if torch.distributed.get_rank() == 0: + iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1) + else: + iterator_data = enumerate(dataset_iterator) + + engine.train() + + for step, batch_data in iterator_data: + + # batch_data = pretrain_dataset_provider.get_batch(batch_index) + input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}") + attention_mask = batch_data[1].cuda(f"cuda:{torch.cuda.current_device()}") + token_type_ids = batch_data[2].cuda(f"cuda:{torch.cuda.current_device()}") + mlm_label = batch_data[3].cuda(f"cuda:{torch.cuda.current_device()}") + # nsp_label = batch_data[5].cuda() + + output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + loss = engine.criterion(output.logits, mlm_label) + pretrain_dataset_provider.prefetch_batch() + + engine.backward(loss) + train_loss += loss.float().item() + # if (step + 1) % args.accumulation_step == 0: + engine.step() + lr_scheduelr.step() + engine.zero_grad() + + global_step += 1 + + if global_step % args.log_interval == 0 and global_step != 0 \ + and torch.distributed.get_rank() == 0: + elapsed_time = timers('interval_time').elapsed(reset=False) + elapsed_time_per_iteration = elapsed_time / global_step + samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(numel, args, config, elapsed_time, global_step, world_size) + + cur_loss = train_loss / args.log_interval + current_lr = lr_scheduelr.get_last_lr()[0] + log_str = f'| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ + f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}' + logger.info(log_str, print_=False) + + if args.wandb: + tensorboard_log = get_tensorboard_writer() + tensorboard_log.log_train({ + 'lr': current_lr, + 'loss': cur_loss, + 'ppl': math.exp(cur_loss), + 'mins_batch': elapsed_time_per_iteration + }, global_step) + + train_loss = 0 + + logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60 :.3f} mins') + logger.info('*' * 100) + + eval_loss += evaluate(engine, args, logger, global_step) + save_ckpt(engine.model, optimizer, lr_scheduelr, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step) + + + eval_loss /= len(os.listdir(args.data_path_prefix)) + logger.info(f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' + \ + f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}') + logger.info('-' * 100) + if args.wandb and torch.distributed.get_rank() == 0: + tensorboard_log = get_tensorboard_writer() + tensorboard_log.log_eval({ + 'all_eval_shard_loss': eval_loss, + }, epoch) + start_shard = 0 + eval_loss = 0 + + pretrain_dataset_provider.release_shard() + + logger.info('Congratulation, training has finished!!!') + + +if __name__ == '__main__': + main() diff --git a/examples/language/roberta/pretraining/utils/WandbLog.py b/examples/language/roberta/pretraining/utils/WandbLog.py new file mode 100644 index 0000000000000000000000000000000000000000..9dd28a98186ba684eea28360f14d8060b5ccdd51 --- /dev/null +++ b/examples/language/roberta/pretraining/utils/WandbLog.py @@ -0,0 +1,46 @@ +import time +import wandb +import os +from torch.utils.tensorboard import SummaryWriter + +class WandbLog: + + @classmethod + def init_wandb(cls, project, notes=None, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None): + wandb.init(project=project, notes=notes, name=name, config=config) + + @classmethod + def log(cls, result, model=None, gradient=None): + wandb.log(result) + + if model: + wandb.watch(model) + + if gradient: + wandb.watch(gradient) + + +class TensorboardLog: + + def __init__(self, location, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None): + if not os.path.exists(location): + os.mkdir(location) + self.writer = SummaryWriter(location, comment=name) + + def log_train(self, result, step): + for k, v in result.items(): + self.writer.add_scalar(f'{k}/train', v, step) + + def log_eval(self, result, step): + for k, v in result.items(): + self.writer.add_scalar(f'{k}/eval', v, step) + + def log_zeroshot(self, result, step): + for k, v in result.items(): + self.writer.add_scalar(f'{k}_acc/eval', v, step) + + + + + + diff --git a/examples/language/roberta/pretraining/utils/exp_util.py b/examples/language/roberta/pretraining/utils/exp_util.py new file mode 100644 index 0000000000000000000000000000000000000000..a02b0872acbcf6ac799bb3a589bf35cb050b1071 --- /dev/null +++ b/examples/language/roberta/pretraining/utils/exp_util.py @@ -0,0 +1,99 @@ +import functools +import os, shutil +import torch +import psutil +from colossalai.core import global_context as gpc + +def logging(s, log_path, print_=True, log_=True): + if print_: + print(s) + if log_: + with open(log_path, 'a+') as f_log: + f_log.write(s + '\n') + +def get_logger(log_path, **kwargs): + return functools.partial(logging, log_path=log_path, **kwargs) + +def create_exp_dir(dir_path, scripts_to_save=None, debug=False): + if debug: + print('Debug Mode : no experiment dir created') + return functools.partial(logging, log_path=None, log_=False) + + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + print('Experiment dir : {}'.format(dir_path)) + if scripts_to_save is not None: + script_path = os.path.join(dir_path, 'scripts') + if not os.path.exists(script_path): + os.makedirs(script_path) + for script in scripts_to_save: + dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) + shutil.copyfile(script, dst_file) + + return get_logger(log_path=os.path.join(dir_path, 'log.txt')) + +def get_cpu_mem(): + return psutil.Process().memory_info().rss / 1024**2 + + +def get_gpu_mem(): + return torch.cuda.memory_allocated() / 1024**2 + + +def get_mem_info(prefix=''): + return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) + + +def get_parameters_in_billions(model, world_size=1): + gpus_per_model = world_size + + approx_parameters_in_billions = sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()]) + for model_module in model]) + + return approx_parameters_in_billions * gpus_per_model / (1e9) + +def throughput_calculator(numel, args, config, iteration_time, total_iterations, world_size=1): + gpus_per_model = 1 + batch_size = args.train_micro_batch_size_per_gpu + samples_per_model = batch_size * args.max_seq_length + model_replica_count = world_size / gpus_per_model + approx_parameters_in_billions = numel + elapsed_time_per_iter = iteration_time / total_iterations + samples_per_second = batch_size / elapsed_time_per_iter + + #flops calculator + hidden_size = config.hidden_size + num_layers = config.num_hidden_layers + vocab_size = config.vocab_size + + # General TFLOPs formula (borrowed from Equation 3 in Section 5.1 of + # https://arxiv.org/pdf/2104.04473.pdf). + # The factor of 4 is when used with activation check-pointing, + # otherwise it will be 3. + checkpoint_activations_factor = 4 if args.checkpoint_activations else 3 + flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size))) + tflops = flops_per_iteration / (elapsed_time_per_iter * (10**12)) + return samples_per_second, tflops, approx_parameters_in_billions + +def synchronize(): + if not torch.distributed.is_available(): + return + if not torch.distributed.is_intialized(): + return + world_size = torch.distributed.get_world_size() + if world_size == 1: + return + torch.distributed.barrier() + +def log_args(logger, args): + logger.info('--------args----------') + message = '\n'.join([f'{k:<30}: {v}' for k, v in vars(args).items()]) + message += '\n' + message += '\n'.join([f'{k:<30}: {v}' for k, v in gpc.config.items()]) + logger.info(message) + logger.info('--------args----------\n') \ No newline at end of file diff --git a/examples/language/roberta/pretraining/utils/global_vars.py b/examples/language/roberta/pretraining/utils/global_vars.py new file mode 100644 index 0000000000000000000000000000000000000000..363cbf91c06563724f8abddf6610e8857694faf9 --- /dev/null +++ b/examples/language/roberta/pretraining/utils/global_vars.py @@ -0,0 +1,126 @@ +import time +import torch +from .WandbLog import TensorboardLog + +_GLOBAL_TIMERS = None +_GLOBAL_TENSORBOARD_WRITER = None + + +def set_global_variables(launch_time, tensorboard_path): + _set_timers() + _set_tensorboard_writer(launch_time, tensorboard_path) + +def _set_timers(): + """Initialize timers.""" + global _GLOBAL_TIMERS + _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers') + _GLOBAL_TIMERS = Timers() + +def _set_tensorboard_writer(launch_time, tensorboard_path): + """Set tensorboard writer.""" + global _GLOBAL_TENSORBOARD_WRITER + _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, + 'tensorboard writer') + if torch.distributed.get_rank() == 0: + _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f'/{launch_time}', launch_time) + +def get_timers(): + """Return timers.""" + _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers') + return _GLOBAL_TIMERS + +def get_tensorboard_writer(): + """Return tensorboard writer. It can be None so no need + to check if it is initialized.""" + return _GLOBAL_TENSORBOARD_WRITER + +def _ensure_var_is_initialized(var, name): + """Make sure the input variable is not None.""" + assert var is not None, '{} is not initialized.'.format(name) + + +def _ensure_var_is_not_initialized(var, name): + """Make sure the input variable is not None.""" + assert var is None, '{} is already initialized.'.format(name) + + +class _Timer: + """Timer.""" + + def __init__(self, name): + self.name_ = name + self.elapsed_ = 0.0 + self.started_ = False + self.start_time = time.time() + + def start(self): + """Start the timer.""" + # assert not self.started_, 'timer has already been started' + torch.cuda.synchronize() + self.start_time = time.time() + self.started_ = True + + def stop(self): + """Stop the timer.""" + assert self.started_, 'timer is not started' + torch.cuda.synchronize() + self.elapsed_ += (time.time() - self.start_time) + self.started_ = False + + def reset(self): + """Reset timer.""" + self.elapsed_ = 0.0 + self.started_ = False + + def elapsed(self, reset=True): + """Calculate the elapsed time.""" + started_ = self.started_ + # If the timing in progress, end it first. + if self.started_: + self.stop() + # Get the elapsed time. + elapsed_ = self.elapsed_ + # Reset the elapsed time + if reset: + self.reset() + # If timing was in progress, set it back. + if started_: + self.start() + return elapsed_ + + +class Timers: + """Group of timers.""" + + def __init__(self): + self.timers = {} + + def __call__(self, name): + if name not in self.timers: + self.timers[name] = _Timer(name) + return self.timers[name] + + def write(self, names, writer, iteration, normalizer=1.0, reset=False): + """Write timers to a tensorboard writer""" + # currently when using add_scalars, + # torch.utils.add_scalars makes each timer its own run, which + # polutes the runs list, so we just add each as a scalar + assert normalizer > 0.0 + for name in names: + value = self.timers[name].elapsed(reset=reset) / normalizer + writer.add_scalar(name + '-time', value, iteration) + + def log(self, names, normalizer=1.0, reset=True): + """Log a group of timers.""" + assert normalizer > 0.0 + string = 'time (ms)' + for name in names: + elapsed_time = self.timers[name].elapsed( + reset=reset) * 1000.0 / normalizer + string += ' | {}: {:.2f}'.format(name, elapsed_time) + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == ( + torch.distributed.get_world_size() - 1): + print(string, flush=True) + else: + print(string, flush=True) diff --git a/examples/language/roberta/pretraining/utils/logger.py b/examples/language/roberta/pretraining/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..481c4c6ce61bd6a8bade07ca9bbba57803aae119 --- /dev/null +++ b/examples/language/roberta/pretraining/utils/logger.py @@ -0,0 +1,31 @@ +import os +import logging +import torch.distributed as dist + +logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO) +logger = logging.getLogger(__name__) + + +class Logger(): + def __init__(self, log_path, cuda=False, debug=False): + self.logger = logging.getLogger(__name__) + self.cuda = cuda + self.log_path = log_path + self.debug = debug + + + def info(self, message, log_=True, print_=True, *args, **kwargs): + if (self.cuda and dist.get_rank() == 0) or not self.cuda: + if print_: + self.logger.info(message, *args, **kwargs) + + if log_: + with open(self.log_path, 'a+') as f_log: + f_log.write(message + '\n') + + + def error(self, message, *args, **kwargs): + self.logger.error(message, *args, **kwargs) diff --git a/examples/tutorial/.gitignore b/examples/tutorial/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f873b6a4abbdae50cab83899b5b4a7554a5266e3 --- /dev/null +++ b/examples/tutorial/.gitignore @@ -0,0 +1 @@ +./data/ diff --git a/examples/tutorial/README.md b/examples/tutorial/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bef7c8905033ea47ec25626160498744f00d0ee2 --- /dev/null +++ b/examples/tutorial/README.md @@ -0,0 +1,193 @@ +# Colossal-AI Tutorial Hands-on + +## Introduction + +Welcome to the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) tutorial, which has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), etc. + + +[Colossal-AI](https://github.com/hpcaitech/ColossalAI), a unified deep learning system for the big model era, integrates +many advanced technologies such as multi-dimensional tensor parallelism, sequence parallelism, heterogeneous memory management, +large-scale optimization, adaptive task scheduling, etc. By using Colossal-AI, we could help users to efficiently and +quickly deploy large AI model training and inference, reducing large AI model training budgets and scaling down the labor cost of learning and deployment. + +### ๐Ÿš€ Quick Links + +[**Colossal-AI**](https://github.com/hpcaitech/ColossalAI) | +[**Paper**](https://arxiv.org/abs/2110.14883) | +[**Documentation**](https://www.colossalai.org/) | +[**Forum**](https://github.com/hpcaitech/ColossalAI/discussions) | +[**Slack**](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) + +## Table of Content + + - Multi-dimensional Parallelism + - Know the components and sketch of Colossal-AI + - Step-by-step from PyTorch to Colossal-AI + - Try data/pipeline parallelism and 1D/2D/2.5D/3D tensor parallelism using a unified model + - Sequence Parallelism + - Try sequence parallelism with BERT + - Combination of data/pipeline/sequence parallelism + - Faster training and longer sequence length + - Large Batch Training Optimization + - Comparison of small/large batch size with SGD/LARS optimizer + - Acceleration from a larger batch size + - Auto-Parallelism + - Parallelism with normal non-distributed training code + - Model tracing + solution solving + runtime communication inserting all in one auto-parallelism system + - Try single program, multiple data (SPMD) parallel with auto-parallelism SPMD solver on ResNet50 + - Fine-tuning and Serving for OPT + - Try pre-trained OPT model weights with Colossal-AI + - Fine-tuning OPT with limited hardware using ZeRO, Gemini and parallelism + - Deploy the fine-tuned model to inference service + - Acceleration of Stable Diffusion + - Stable Diffusion with Lightning + - Try Lightning Colossal-AI strategy to optimize memory and accelerate speed + + +## Discussion + +Discussion about the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) project is always welcomed! We would love to exchange ideas with the community to better help this project grow. +If you think there is a need to discuss anything, you may jump to our [Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w). + +If you encounter any problem while running these tutorials, you may want to raise an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) in this repository. + +## ๐Ÿ› ๏ธ Setup environment +You should use `conda` to create a virtual environment, we recommend **python 3.8**, e.g. `conda create -n colossal python=3.8`. This installation commands are for CUDA 11.3, if you have a different version of CUDA, please download PyTorch and Colossal-AI accordingly. + +``` +# install torch +# visit https://pytorch.org/get-started/locally/ to download other versions +pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 + +# install latest ColossalAI +# visit https://colossalai.org/download to download corresponding version of Colossal-AI +pip install colossalai==0.1.11rc3+torch1.12cu11.3 -f https://release.colossalai.org +``` + +You can run `colossalai check -i` to verify if you have correctly set up your environment ๐Ÿ•น๏ธ. +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/colossalai%20check%20-i.png) + +If you encounter messages like `please install with cuda_ext`, do let me know as it could be a problem of the distribution wheel. ๐Ÿ˜ฅ + +Then clone the Colossal-AI repository from GitHub. +```bash +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI/examples/tutorial +``` + +## ๐Ÿ”ฅ Multi-dimensional Hybrid Parallel with Vision Transformer +1. Go to **hybrid_parallel** folder in the **tutorial** directory. +2. Install our model zoo. +```bash +pip install titans +``` +3. Run with synthetic data which is of similar shape to CIFAR10 with the `-s` flag. +```bash +colossalai run --nproc_per_node 4 train.py --config config.py -s +``` + +4. Modify the config file to play with different types of tensor parallelism, for example, change tensor parallel size to be 4 and mode to be 2d and run on 8 GPUs. + +## โ˜€๏ธ Sequence Parallel with BERT +1. Go to the **sequence_parallel** folder in the **tutorial** directory. +2. Run with the following command +```bash +export PYTHONPATH=$PWD +colossalai run --nproc_per_node 4 train.py -s +``` +3. The default config is sequence parallel size = 2, pipeline size = 1, letโ€™s change pipeline size to be 2 and try it again. + +## ๐Ÿ“• Large batch optimization with LARS and LAMB +1. Go to the **large_batch_optimizer** folder in the **tutorial** directory. +2. Run with synthetic data +```bash +colossalai run --nproc_per_node 4 train.py --config config.py -s +``` + +## ๐Ÿ˜€ Auto-Parallel Tutorial +1. Go to the **auto_parallel** folder in the **tutorial** directory. +2. Install `pulp` and `coin-or-cbc` for the solver. +```bash +pip install pulp +conda install -c conda-forge coin-or-cbc +``` +2. Run the auto parallel resnet example with 4 GPUs with synthetic dataset. +```bash +colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py -s +``` + +You should expect to the log like this. This log shows the edge cost on the computation graph as well as the sharding strategy for an operation. For example, `layer1_0_conv1 S01R = S01R X RR` means that the first dimension (batch) of the input and output is sharded while the weight is not sharded (S means sharded, R means replicated), simply equivalent to data parallel training. +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/auto-parallel%20demo.png) + +## ๐ŸŽ† Auto-Checkpoint Tutorial +1. Stay in the `auto_parallel` folder. +2. Install the dependencies. +```bash +pip install matplotlib transformers +``` +3. Run a simple resnet50 benchmark to automatically checkpoint the model. +```bash +python auto_ckpt_solver_test.py --model resnet50 +``` + +You should expect the log to be like this +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/auto-ckpt%20demo.png) + +This shows that given different memory budgets, the model is automatically injected with activation checkpoint and its time taken per iteration. You can run this benchmark for GPT as well but it can much longer since the model is larger. +```bash +python auto_ckpt_solver_test.py --model gpt2 +``` + +4. Run a simple benchmark to find the optimal batch size for checkpointed model. +```bash +python auto_ckpt_batchsize_test.py +``` + +You can expect the log to be like +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/auto-ckpt%20batchsize.png) + +## ๐Ÿš€ Run OPT finetuning and inference +1. Install the dependency +```bash +pip install datasets accelerate +``` +2. Run finetuning with synthetic datasets with one GPU +```bash +bash ./run_clm_synthetic.sh +``` +3. Run finetuning with 4 GPUs +```bash +bash ./run_clm_synthetic.sh 16 0 125m 4 +``` +4. Run inference with OPT 125M +```bash +docker hpcaitech/tutorial:opt-inference +docker run -it --rm --gpus all --ipc host -p 7070:7070 hpcaitech/tutorial:opt-inference +``` +5. Start the http server inside the docker container with tensor parallel size 2 +```bash +python opt_fastapi.py opt-125m --tp 2 --checkpoint /data/opt-125m +``` + +## ๐Ÿ–ผ๏ธ Accelerate Stable Diffusion with Colossal-AI +1. Create a new environment for diffusion +```bash +conda env create -f environment.yaml +conda activate ldm +``` +2. Install Colossal-AI from our official page +```bash +pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org +``` +3. Install PyTorch Lightning compatible commit +```bash +git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa +pip install -r requirements.txt && pip install . +cd .. +``` + +4. Comment out the `from_pretrained` field in the `train_colossalai_cifar10.yaml`. +5. Run training with CIFAR10. +```bash +python main.py -logdir /tmp -t true -postfix test -b configs/train_colossalai_cifar10.yaml +``` diff --git a/examples/tutorial/auto_parallel/README.md b/examples/tutorial/auto_parallel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e99a018c2da1935f2de3b212c1a3800b3abd8079 --- /dev/null +++ b/examples/tutorial/auto_parallel/README.md @@ -0,0 +1,106 @@ +# Auto-Parallelism with ResNet + +## ๐Ÿš€Quick Start +### Auto-Parallel Tutorial +1. Install `pulp` and `coin-or-cbc` for the solver. +```bash +pip install pulp +conda install -c conda-forge coin-or-cbc +``` +2. Run the auto parallel resnet example with 4 GPUs with synthetic dataset. +```bash +colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py -s +``` + +You should expect to the log like this. This log shows the edge cost on the computation graph as well as the sharding strategy for an operation. For example, `layer1_0_conv1 S01R = S01R X RR` means that the first dimension (batch) of the input and output is sharded while the weight is not sharded (S means sharded, R means replicated), simply equivalent to data parallel training. +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/auto-parallel%20demo.png) + + +### Auto-Checkpoint Tutorial +1. Stay in the `auto_parallel` folder. +2. Install the dependencies. +```bash +pip install matplotlib transformers +``` +3. Run a simple resnet50 benchmark to automatically checkpoint the model. +```bash +python auto_ckpt_solver_test.py --model resnet50 +``` + +You should expect the log to be like this +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/auto-ckpt%20demo.png) + +This shows that given different memory budgets, the model is automatically injected with activation checkpoint and its time taken per iteration. You can run this benchmark for GPT as well but it can much longer since the model is larger. +```bash +python auto_ckpt_solver_test.py --model gpt2 +``` + +4. Run a simple benchmark to find the optimal batch size for checkpointed model. +```bash +python auto_ckpt_batchsize_test.py +``` + +You can expect the log to be like +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/auto-ckpt%20batchsize.png) + + +## Prepare Dataset + +We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`. +The dataset will be downloaded to `colossalai/examples/tutorials/data` by default. +If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command. + +```bash +export DATA=/path/to/data +``` + +## extra requirements to use autoparallel + +```bash +pip install pulp +conda install coin-or-cbc +``` + +## Run on 2*2 device mesh + +```bash +colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py +``` + +## Auto Checkpoint Benchmarking + +We prepare two bechmarks for you to test the performance of auto checkpoint + +The first test `auto_ckpt_solver_test.py` will show you the ability of solver to search checkpoint strategy that could fit in the given budget (test on GPT2 Medium and ResNet 50). It will output the benchmark summary and data visualization of peak memory vs. budget memory and relative step time vs. peak memory. + +The second test `auto_ckpt_batchsize_test.py` will show you the advantage of fitting larger batchsize training into limited GPU memory with the help of our activation checkpoint solver (test on ResNet152). It will output the benchmark summary. + +The usage of the above two test +```bash +# run auto_ckpt_solver_test.py on gpt2 medium +python auto_ckpt_solver_test.py --model gpt2 + +# run auto_ckpt_solver_test.py on resnet50 +python auto_ckpt_solver_test.py --model resnet50 + +# tun auto_ckpt_batchsize_test.py +python auto_ckpt_batchsize_test.py +``` + +There are some results for your reference + +## Auto Checkpoint Solver Test + +### ResNet 50 +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/tutorial/resnet50_benchmark.png) + +### GPT2 Medium +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/tutorial/gpt2_benchmark.png) + +## Auto Checkpoint Batch Size Test +```bash +===============test summary================ +batch_size: 512, peak memory: 73314.392 MB, through put: 254.286 images/s +batch_size: 1024, peak memory: 73316.216 MB, through put: 397.608 images/s +batch_size: 2048, peak memory: 72927.837 MB, through put: 277.429 images/s +``` diff --git a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5decfc695f6fe5c7c2b0249a5ed42162d32c6bb3 --- /dev/null +++ b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py @@ -0,0 +1,59 @@ +import time +from argparse import ArgumentParser +from copy import deepcopy +from functools import partial + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.multiprocessing as mp +import torchvision.models as tm +from bench_utils import bench, data_gen_resnet + +import colossalai +from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor +from colossalai.fx import metainfo_trace, symbolic_trace +from colossalai.utils import free_port + + +def _benchmark(rank, world_size, port): + """Auto activation checkpoint batchsize benchmark + This benchmark test the through put of Resnet152 with our activation solver given the memory budget of 95% of + maximum GPU memory, and with the batch size of [512, 1024, 2048], you could see that using auto activation + checkpoint with optimality guarantee, we might be able to find better batch size for the model, as larger batch + size means that we are able to use larger portion of GPU FLOPS, while recomputation scheduling with our solver + only result in minor performance drop. So at last we might be able to find better training batch size for our + model (combine with large batch training optimizer such as LAMB). + """ + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = tm.resnet152() + gm = symbolic_trace(model) + raw_graph = deepcopy(gm.graph) + peak_mems, through_puts, batch_sizes = [], [], [512, 1024, 2048] + for batch_size in batch_sizes: + batch_size = int(batch_size) + gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta')) + solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info()[0] * 0.95) + gm.graph = solver.solve() + peak_mem, step_time = bench(gm, + torch.nn.CrossEntropyLoss(), + partial(data_gen_resnet, batch_size=batch_size, shape=(3, 224, 224)), + num_steps=5) + peak_mems.append(peak_mem) + through_puts.append(batch_size / step_time * 1.0e3) + gm.graph = deepcopy(raw_graph) + + # print results + print("===============benchmark summary================") + for batch_size, peak_mem, through_put in zip(batch_sizes, peak_mems, through_puts): + print(f'batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s') + + +def auto_activation_checkpoint_batchsize_benchmark(): + world_size = 1 + run_func_module = partial(_benchmark, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == "__main__": + auto_activation_checkpoint_batchsize_benchmark() diff --git a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ab0f2ef661dfe4852f76b6d439fdd8d828f326bb --- /dev/null +++ b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py @@ -0,0 +1,89 @@ +import time +from argparse import ArgumentParser +from functools import partial + +import matplotlib.pyplot as plt +import torch +import torch.multiprocessing as mp +import torchvision.models as tm +from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium + +import colossalai +from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor +from colossalai.fx import metainfo_trace, symbolic_trace +from colossalai.utils import free_port + + +def _benchmark(rank, world_size, port, args): + """ + Auto activation checkpoint solver benchmark, we provide benchmark on two models: gpt2_medium and resnet50. + The benchmark will sample in a range of memory budget for each model and output the benchmark summary and + data visualization of peak memory vs. budget memory and relative step time vs. peak memory. + """ + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + if args.model == 'resnet50': + model = tm.resnet50() + data_gen = partial(data_gen_resnet, batch_size=128, shape=(3, 224, 224)) + gm = symbolic_trace(model) + gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device='meta')) + loss = torch.nn.CrossEntropyLoss() + else: + model = gpt2_medium() + data_gen = partial(data_gen_gpt2, batch_size=8, seq_len=1024, vocab_size=50257) + data, mask = data_gen(device='meta')[0] + gm = symbolic_trace(model, meta_args={'input_ids': data, 'attention_mask': mask}) + gm = metainfo_trace(gm, data, mask) + loss = GPTLMLoss() + + free_memory = 11000 * 1024**2 if args.model == 'resnet50' else 56000 * 1024**2 + start_factor = 4 if args.model == 'resnet50' else 10 + + # trace and benchmark + budgets, peak_hist, step_hist = bench_rotor(gm, + loss, + data_gen, + num_steps=5, + sample_points=15, + free_memory=free_memory, + start_factor=start_factor) + + # print summary + print("==============benchmark summary==============") + for budget, peak, step in zip(budgets, peak_hist, step_hist): + print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS') + + # plot valid results + fig, axs = plt.subplots(1, 2, figsize=(16, 8)) + valid_idx = step_hist.index(next(step for step in step_hist if step != float("inf"))) + + # plot peak memory vs. budget memory + axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:]) + axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--') + axs[0].set_xlabel("Budget Memory (MB)") + axs[0].set_ylabel("Peak Memory (MB)") + axs[0].set_title("Peak Memory vs. Budget Memory") + + # plot relative step time vs. budget memory + axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]]) + axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--') + axs[1].set_xlabel("Peak Memory (MB)") + axs[1].set_ylabel("Relative Step Time") + axs[1].set_title("Step Time vs. Peak Memory") + axs[1].set_ylim(0.8, 1.5) + + # save plot + fig.savefig(f"{args.model}_benchmark.png") + + +def auto_activation_checkpoint_benchmark(args): + world_size = 1 + run_func_module = partial(_benchmark, world_size=world_size, port=free_port(), args=args) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == "__main__": + parser = ArgumentParser("Auto Activation Checkpoint Solver Benchmark") + parser.add_argument("--model", type=str, default='gpt2', choices=['gpt2', 'resnet50']) + args = parser.parse_args() + + auto_activation_checkpoint_benchmark(args) diff --git a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e4aff13e484af1769ca19a636ded0191be50a021 --- /dev/null +++ b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py @@ -0,0 +1,200 @@ +import argparse +import os +from pathlib import Path + +import torch +from titans.utils import barrier_context +from torch.fx import GraphModule +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet50 +from tqdm import tqdm + +import colossalai +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph +from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser +from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions +from colossalai.auto_parallel.tensor_shard.solver.solver import Solver +from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor +from colossalai.core import global_context as gpc +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.logging import get_dist_logger +from colossalai.nn.lr_scheduler import CosineAnnealingLR +from colossalai.utils import get_dataloader + +DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute() + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('-s', '--synthetic', action="store_true", help="use synthetic dataset instead of CIFAR10") + return parser.parse_args() + + +def synthesize_data(): + img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32) + label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,)) + return img, label + + +def main(): + args = parse_args() + colossalai.launch_from_torch(config='./config.py') + + logger = get_dist_logger() + + if not args.synthetic: + with barrier_context(): + # build dataloaders + train_dataset = CIFAR10(root=DATA_ROOT, + download=True, + transform=transforms.Compose([ + transforms.RandomCrop(size=32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], + std=[0.2023, 0.1994, 0.2010]), + ])) + + test_dataset = CIFAR10(root=DATA_ROOT, + train=False, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), + ])) + + train_dataloader = get_dataloader( + dataset=train_dataset, + add_sampler=True, + shuffle=True, + batch_size=gpc.config.BATCH_SIZE, + pin_memory=True, + ) + + test_dataloader = get_dataloader( + dataset=test_dataset, + add_sampler=True, + batch_size=gpc.config.BATCH_SIZE, + pin_memory=True, + ) + else: + train_dataloader, test_dataloader = None, None + + # initialize device mesh + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # trace the model with meta data + tracer = ColoTracer() + model = resnet50(num_classes=10).cuda() + input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')} + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + # prepare info for solver + solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + graph_analyser = GraphAnalyser(gm) + + # solve the solution + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + ret = solver.call_solver_serialized_args() + solution = list(ret[0]) + if gpc.get_global_rank() == 0: + for index, node in enumerate(graph.nodes): + print(node.name, node.strategies_vector[solution[index]].name) + + # process the graph for distributed training ability + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh) + gm = runtime_apply_pass(gm) + gm.recompile() + + # build criterion + criterion = torch.nn.CrossEntropyLoss() + + # optimizer + optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) + + # lr_scheduler + lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS) + + for epoch in range(gpc.config.NUM_EPOCHS): + gm.train() + + if args.synthetic: + # if we use synthetic data + # we assume it only has 30 steps per epoch + num_steps = range(30) + + else: + # we use the actual number of steps for training + num_steps = range(len(train_dataloader)) + data_iter = iter(train_dataloader) + progress = tqdm(num_steps) + + for _ in progress: + if args.synthetic: + # generate fake data + img, label = synthesize_data() + else: + # get the real data + img, label = next(data_iter) + + img = img.cuda() + label = label.cuda() + optimizer.zero_grad() + output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict) + train_loss = criterion(output, label) + train_loss.backward(train_loss) + optimizer.step() + lr_scheduler.step() + + # run evaluation + gm.eval() + correct = 0 + total = 0 + + if args.synthetic: + # if we use synthetic data + # we assume it only has 10 steps for evaluation + num_steps = range(30) + + else: + # we use the actual number of steps for training + num_steps = range(len(test_dataloader)) + data_iter = iter(test_dataloader) + progress = tqdm(num_steps) + + for _ in progress: + if args.synthetic: + # generate fake data + img, label = synthesize_data() + else: + # get the real data + img, label = next(data_iter) + + img = img.cuda() + label = label.cuda() + + with torch.no_grad(): + output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict) + test_loss = criterion(output, label) + pred = torch.argmax(output, dim=-1) + correct += torch.sum(pred == label) + total += img.size(0) + + logger.info( + f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}", + ranks=[0]) + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/auto_parallel/bench_utils.py b/examples/tutorial/auto_parallel/bench_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..69859f885ae601efcec049b4771df831ed0da009 --- /dev/null +++ b/examples/tutorial/auto_parallel/bench_utils.py @@ -0,0 +1,170 @@ +import time +from copy import deepcopy +from functools import partial +from typing import Callable, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torchvision.models as tm +from transformers import GPT2Config, GPT2LMHeadModel + +from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor +from colossalai.fx import metainfo_trace + + +def bench(gm: torch.fx.GraphModule, + criterion: torch.nn.Module, + data_gen: Callable, + num_steps: int = 5) -> Tuple[int, int]: + """Benchmarking a given graph module + Args: + gm (torch.fx.GraphModule): The graph module to benchmark. + criterion (torch.nn.Module): Loss function. + data_gen (Callable): Data generator. + num_steps (int, optional): Number of test steps. Defaults to 5. + Returns: + Tuple[int, int]: peak memory in MB and step time in MS. + """ + gm.train() + gm.cuda() + step_time = float('inf') + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + cached = torch.cuda.max_memory_allocated(device="cuda") + try: + for _ in range(num_steps): + args, label = data_gen() + output, loss = None, None + + torch.cuda.synchronize(device="cuda") + start = time.time() + output = gm(*args) + loss = criterion(output, label) + loss.backward() + torch.cuda.synchronize(device="cuda") + step_time = min(step_time, time.time() - start) + + for child in gm.children(): + for param in child.parameters(): + param.grad = None + del args, label, output, loss + except: + del args, label, output, loss + gm.to("cpu") + torch.cuda.empty_cache() + peak_mem = (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2 + return peak_mem, step_time * 1.0e3 + + +def bench_rotor(gm: torch.fx.GraphModule, + criterion: torch.nn.Module, + data_gen: Callable, + num_steps: int = 5, + sample_points: int = 20, + free_memory: int = torch.cuda.mem_get_info()[0], + start_factor: int = 4) -> Tuple[np.array, list, list]: + """Auto Checkpoint Rotor Algorithm benchmarking + Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data. + Args: + gm (torch.fx.GraphModule): The graph module to benchmark. + criterion (torch.nn.Module): Loss function. + data_gen (Callable): Data generator. + num_steps (int, optional): Number of test steps. Defaults to 5. + sample_points (int, optional): Number of sample points. Defaults to 20. + free_memory (int, optional): Max memory budget in Byte. Defaults to torch.cuda.mem_get_info()[0]. + start_factor (int, optional): Start memory budget factor for benchmark, the start memory budget + will be free_memory / start_factor. Defaults to 4. + Returns: + Tuple[np.array, list, list]: return budgets vector (MB), peak memory vector (MB), step time vector (MS). + """ + peak_hist, step_hist = [], [] + raw_graph = deepcopy(gm.graph) + for budget in np.linspace(free_memory // start_factor, free_memory, sample_points): + gm = metainfo_trace(gm, *data_gen()[0]) + solver = CheckpointSolverRotor(gm.graph, free_memory=budget) + try: + gm.graph = solver.solve(verbose=False) + peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps) + except: + peak_memory, step_time = budget / 1024**2, float('inf') + peak_hist.append(peak_memory) + step_hist.append(step_time) + gm.graph = deepcopy(raw_graph) + return np.linspace(free_memory // start_factor, free_memory, sample_points) / 1024**2, peak_hist, step_hist + + +class GPTLMModel(nn.Module): + """ + GPT Model + """ + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +class GPTLMLoss(nn.Module): + """ + GPT Loss + """ + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +def gpt2_medium(checkpoint=False): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_xl(checkpoint=False): + return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint) + + +def gpt2_6b(checkpoint=False): + return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint) + + +def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'): + """ + Generate random data for gpt2 benchmarking + """ + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + attention_mask = torch.ones_like(input_ids, device=device) + return (input_ids, attention_mask), attention_mask + + +def data_gen_resnet(batch_size, shape, device='cuda:0'): + """ + Generate random data for resnet benchmarking + """ + data = torch.empty(batch_size, *shape, device=device) + label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000) + return (data,), label diff --git a/examples/tutorial/auto_parallel/config.py b/examples/tutorial/auto_parallel/config.py new file mode 100644 index 0000000000000000000000000000000000000000..fa14eda740f7e200a750d1e3bc9f806cb45b55fe --- /dev/null +++ b/examples/tutorial/auto_parallel/config.py @@ -0,0 +1,2 @@ +BATCH_SIZE = 128 +NUM_EPOCHS = 10 diff --git a/examples/tutorial/download_cifar10.py b/examples/tutorial/download_cifar10.py new file mode 100644 index 0000000000000000000000000000000000000000..5c6b6988ade531f9a6e77955803b7c2dbd88ca9a --- /dev/null +++ b/examples/tutorial/download_cifar10.py @@ -0,0 +1,13 @@ +import os + +from torchvision.datasets import CIFAR10 + + +def main(): + dir_path = os.path.dirname(os.path.realpath(__file__)) + data_root = os.path.join(dir_path, 'data') + dataset = CIFAR10(root=data_root, download=True) + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/hybrid_parallel/README.md b/examples/tutorial/hybrid_parallel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6f975e86330a2276caad064f8c33672b7474f7c8 --- /dev/null +++ b/examples/tutorial/hybrid_parallel/README.md @@ -0,0 +1,45 @@ +# Multi-dimensional Parallelism with Colossal-AI + + +## ๐Ÿš€Quick Start +1. Install our model zoo. +```bash +pip install titans +``` +2. Run with synthetic data which is of similar shape to CIFAR10 with the `-s` flag. +```bash +colossalai run --nproc_per_node 4 train.py --config config.py -s +``` + +3. Modify the config file to play with different types of tensor parallelism, for example, change tensor parallel size to be 4 and mode to be 2d and run on 8 GPUs. + + +## Install Titans Model Zoo + +```bash +pip install titans +``` + + +## Prepare Dataset + +We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`. +The dataset will be downloaded to `colossalai/examples/tutorials/data` by default. +If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command. + +```bash +export DATA=/path/to/data +``` + + +## Run on 2*2 device mesh + +Current configuration setting on `config.py` is TP=2, PP=2. + +```bash +# train with cifar10 +colossalai run --nproc_per_node 4 train.py --config config.py + +# train with synthetic data +colossalai run --nproc_per_node 4 train.py --config config.py -s +``` diff --git a/examples/tutorial/hybrid_parallel/config.py b/examples/tutorial/hybrid_parallel/config.py new file mode 100644 index 0000000000000000000000000000000000000000..2450ab1c7a7238fb1fb3d04ce4bc955e99791a72 --- /dev/null +++ b/examples/tutorial/hybrid_parallel/config.py @@ -0,0 +1,36 @@ +from colossalai.amp import AMP_TYPE + +# hyperparameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 256 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 10 +WARMUP_EPOCHS = 3 + +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 512 +DEPTH = 4 +NUM_HEADS = 4 +MLP_RATIO = 2 +NUM_CLASSES = 1000 +CHECKPOINT = False +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token + +# parallel setting +TENSOR_PARALLEL_SIZE = 2 +TENSOR_PARALLEL_MODE = '1d' + +parallel = dict( + pipeline=2, + tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), +) + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 + +# pipeline config +NUM_MICRO_BATCHES = parallel['pipeline'] diff --git a/examples/tutorial/hybrid_parallel/train.py b/examples/tutorial/hybrid_parallel/train.py new file mode 100644 index 0000000000000000000000000000000000000000..0f2a207cb17294260ca9ecdf0389afdc67046d62 --- /dev/null +++ b/examples/tutorial/hybrid_parallel/train.py @@ -0,0 +1,145 @@ +import os + +import torch +from titans.dataloader.cifar10 import build_cifar +from titans.model.vit.vit import _create_vit_model +from tqdm import tqdm + +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.pipeline.pipelinable import PipelinableContext +from colossalai.utils import get_dataloader, is_using_pp + + +class DummyDataloader(): + + def __init__(self, length, batch_size): + self.length = length + self.batch_size = batch_size + + def generate(self): + data = torch.rand(self.batch_size, 3, 224, 224) + label = torch.randint(low=0, high=10, size=(self.batch_size,)) + return data, label + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.generate() + else: + raise StopIteration + + def __len__(self): + return self.length + + +def main(): + # initialize distributed setting + parser = colossalai.get_default_parser() + parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data") + args = parser.parse_args() + + # launch from torch + colossalai.launch_from_torch(config=args.config) + + # get logger + logger = get_dist_logger() + logger.info("initialized distributed environment", ranks=[0]) + + if hasattr(gpc.config, 'LOG_PATH'): + if gpc.get_global_rank() == 0: + log_path = gpc.config.LOG_PATH + if not os.path.exists(log_path): + os.mkdir(log_path) + logger.log_to_file(log_path) + + use_pipeline = is_using_pp() + + # create model + model_kwargs = dict(img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=10, + init_method='jax', + checkpoint=gpc.config.CHECKPOINT) + + if use_pipeline: + pipelinable = PipelinableContext() + with pipelinable: + model = _create_vit_model(**model_kwargs) + pipelinable.to_layer_list() + pipelinable.policy = "uniform" + model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) + else: + model = _create_vit_model(**model_kwargs) + + # count number of parameters + total_numel = 0 + for p in model.parameters(): + total_numel += p.numel() + if not gpc.is_initialized(ParallelMode.PIPELINE): + pipeline_stage = 0 + else: + pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}") + + # create dataloaders + root = os.environ.get('DATA', '../data') + if args.synthetic: + # if we use synthetic dataset + # we train for 30 steps and eval for 10 steps per epoch + train_dataloader = DummyDataloader(length=30, batch_size=gpc.config.BATCH_SIZE) + test_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE) + else: + train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True) + + # create loss function + criterion = CrossEntropyLoss(label_smoothing=0.1) + + # create optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, + total_steps=gpc.config.NUM_EPOCHS, + warmup_steps=gpc.config.WARMUP_EPOCHS) + + # initialize + engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader) + + logger.info("Engine is built", ranks=[0]) + + for epoch in range(gpc.config.NUM_EPOCHS): + # training + engine.train() + data_iter = iter(train_dataloader) + + if gpc.get_global_rank() == 0: + description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS) + progress = tqdm(range(len(train_dataloader)), desc=description) + else: + progress = range(len(train_dataloader)) + for _ in progress: + engine.zero_grad() + engine.execute_schedule(data_iter, return_output_label=False) + engine.step() + lr_scheduler.step() + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/large_batch_optimizer/README.md b/examples/tutorial/large_batch_optimizer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..20bddb3834348e75423237a77151f68f8b747e23 --- /dev/null +++ b/examples/tutorial/large_batch_optimizer/README.md @@ -0,0 +1,31 @@ +# Comparison of Large Batch Training Optimization + +## ๐Ÿš€Quick Start +Run with synthetic data +```bash +colossalai run --nproc_per_node 4 train.py --config config.py -s +``` + + +## Prepare Dataset + +We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`. +The dataset will be downloaded to `colossalai/examples/tutorials/data` by default. +If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command. + +```bash +export DATA=/path/to/data +``` + +You can also use synthetic data for this tutorial if you don't wish to download the `CIFAR10` dataset by adding the `-s` or `--synthetic` flag to the command. + + +## Run on 2*2 device mesh + +```bash +# run with cifar10 +colossalai run --nproc_per_node 4 train.py --config config.py + +# run with synthetic dataset +colossalai run --nproc_per_node 4 train.py --config config.py -s +``` diff --git a/examples/tutorial/large_batch_optimizer/config.py b/examples/tutorial/large_batch_optimizer/config.py new file mode 100644 index 0000000000000000000000000000000000000000..e019154e4b127c5f3c8f54f8e411c818933043a5 --- /dev/null +++ b/examples/tutorial/large_batch_optimizer/config.py @@ -0,0 +1,36 @@ +from colossalai.amp import AMP_TYPE + +# hyperparameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 512 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 10 +WARMUP_EPOCHS = 3 + +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 512 +DEPTH = 4 +NUM_HEADS = 4 +MLP_RATIO = 2 +NUM_CLASSES = 1000 +CHECKPOINT = False +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token + +# parallel setting +TENSOR_PARALLEL_SIZE = 2 +TENSOR_PARALLEL_MODE = '1d' + +parallel = dict( + pipeline=2, + tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), +) + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 + +# pipeline config +NUM_MICRO_BATCHES = parallel['pipeline'] diff --git a/examples/tutorial/large_batch_optimizer/train.py b/examples/tutorial/large_batch_optimizer/train.py new file mode 100644 index 0000000000000000000000000000000000000000..d403c275d1af466a5ee72db9ccab257a6baa65c8 --- /dev/null +++ b/examples/tutorial/large_batch_optimizer/train.py @@ -0,0 +1,144 @@ +import os + +import torch +from titans.dataloader.cifar10 import build_cifar +from titans.model.vit.vit import _create_vit_model +from tqdm import tqdm + +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import Lamb, Lars +from colossalai.pipeline.pipelinable import PipelinableContext +from colossalai.utils import get_dataloader, is_using_pp + + +class DummyDataloader(): + + def __init__(self, length, batch_size): + self.length = length + self.batch_size = batch_size + + def generate(self): + data = torch.rand(self.batch_size, 3, 224, 224) + label = torch.randint(low=0, high=10, size=(self.batch_size,)) + return data, label + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.generate() + else: + raise StopIteration + + def __len__(self): + return self.length + + +def main(): + # initialize distributed setting + parser = colossalai.get_default_parser() + parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data") + args = parser.parse_args() + + # launch from torch + colossalai.launch_from_torch(config=args.config) + + # get logger + logger = get_dist_logger() + logger.info("initialized distributed environment", ranks=[0]) + + if hasattr(gpc.config, 'LOG_PATH'): + if gpc.get_global_rank() == 0: + log_path = gpc.config.LOG_PATH + if not os.path.exists(log_path): + os.mkdir(log_path) + logger.log_to_file(log_path) + + use_pipeline = is_using_pp() + + # create model + model_kwargs = dict(img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=10, + init_method='jax', + checkpoint=gpc.config.CHECKPOINT) + + if use_pipeline: + pipelinable = PipelinableContext() + with pipelinable: + model = _create_vit_model(**model_kwargs) + pipelinable.to_layer_list() + pipelinable.policy = "uniform" + model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) + else: + model = _create_vit_model(**model_kwargs) + + # count number of parameters + total_numel = 0 + for p in model.parameters(): + total_numel += p.numel() + if not gpc.is_initialized(ParallelMode.PIPELINE): + pipeline_stage = 0 + else: + pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}") + + # create dataloaders + root = os.environ.get('DATA', '../data/') + if args.synthetic: + train_dataloader = DummyDataloader(length=30, batch_size=gpc.config.BATCH_SIZE) + test_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE) + else: + train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True) + + # create loss function + criterion = CrossEntropyLoss(label_smoothing=0.1) + + # create optimizer + optimizer = Lars(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, + total_steps=gpc.config.NUM_EPOCHS, + warmup_steps=gpc.config.WARMUP_EPOCHS) + + # initialize + engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader) + + logger.info("Engine is built", ranks=[0]) + + for epoch in range(gpc.config.NUM_EPOCHS): + # training + engine.train() + data_iter = iter(train_dataloader) + + if gpc.get_global_rank() == 0: + description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS) + progress = tqdm(range(len(train_dataloader)), desc=description) + else: + progress = range(len(train_dataloader)) + for _ in progress: + engine.zero_grad() + engine.execute_schedule(data_iter, return_output_label=False) + engine.step() + lr_scheduler.step() + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/opt/inference/README.md b/examples/tutorial/opt/inference/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5bacac0d74ad93ad2f9709c0d5a5baad51a1b453 --- /dev/null +++ b/examples/tutorial/opt/inference/README.md @@ -0,0 +1,88 @@ +# Overview + +This is an example showing how to run OPT generation. The OPT model is implemented using ColossalAI. + +It supports tensor parallelism, batching and caching. + +## ๐Ÿš€Quick Start +1. Run inference with OPT 125M +```bash +docker hpcaitech/tutorial:opt-inference +docker run -it --rm --gpus all --ipc host -p 7070:7070 hpcaitech/tutorial:opt-inference +``` +2. Start the http server inside the docker container with tensor parallel size 2 +```bash +python opt_fastapi.py opt-125m --tp 2 --checkpoint /data/opt-125m +``` + +# How to run + +Run OPT-125M: +```shell +python opt_fastapi.py opt-125m +``` + +It will launch a HTTP server on `0.0.0.0:7070` by default and you can customize host and port. You can open `localhost:7070/docs` in your browser to see the openapi docs. + +## Configure + +### Configure model +```shell +python opt_fastapi.py +``` +Available models: opt-125m, opt-6.7b, opt-30b, opt-175b. + +### Configure tensor parallelism +```shell +python opt_fastapi.py --tp +``` +The `` can be an integer in `[1, #GPUs]`. Default `1`. + +### Configure checkpoint +```shell +python opt_fastapi.py --checkpoint +``` +The `` can be a file path or a directory path. If it's a directory path, all files under the directory will be loaded. + +### Configure queue +```shell +python opt_fastapi.py --queue_size +``` +The `` can be an integer in `[0, MAXINT]`. If it's `0`, the request queue size is infinite. If it's a positive integer, when the request queue is full, incoming requests will be dropped (the HTTP status code of response will be 406). + +### Configure bathcing +```shell +python opt_fastapi.py --max_batch_size +``` +The `` can be an integer in `[1, MAXINT]`. The engine will make batch whose size is less or equal to this value. + +Note that the batch size is not always equal to ``, as some consecutive requests may not be batched. + +### Configure caching +```shell +python opt_fastapi.py --cache_size --cache_list_size +``` +This will cache `` unique requests. And for each unique request, it cache `` different results. A random result will be returned if the cache is hit. + +The `` can be an integer in `[0, MAXINT]`. If it's `0`, cache won't be applied. The `` can be an integer in `[1, MAXINT]`. + +### Other configurations +```shell +python opt_fastapi.py -h +``` + +# How to benchmark +```shell +cd benchmark +locust +``` + +Then open the web interface link which is on your console. + +# Pre-process pre-trained weights + +## OPT-66B +See [script/processing_ckpt_66b.py](./script/processing_ckpt_66b.py). + +## OPT-175B +See [script/process-opt-175b](./script/process-opt-175b/). \ No newline at end of file diff --git a/examples/tutorial/opt/inference/batch.py b/examples/tutorial/opt/inference/batch.py new file mode 100644 index 0000000000000000000000000000000000000000..1a0876ca833890fd2bfac3b4d9f342a05c67928f --- /dev/null +++ b/examples/tutorial/opt/inference/batch.py @@ -0,0 +1,59 @@ +import torch +from typing import List, Deque, Tuple, Hashable, Any +from energonai import BatchManager, SubmitEntry, TaskEntry + + +class BatchManagerForGeneration(BatchManager): + def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None: + super().__init__() + self.max_batch_size = max_batch_size + self.pad_token_id = pad_token_id + + def _left_padding(self, batch_inputs): + max_len = max(len(inputs['input_ids']) for inputs in batch_inputs) + outputs = {'input_ids': [], 'attention_mask': []} + for inputs in batch_inputs: + input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] + padding_len = max_len - len(input_ids) + input_ids = [self.pad_token_id] * padding_len + input_ids + attention_mask = [0] * padding_len + attention_mask + outputs['input_ids'].append(input_ids) + outputs['attention_mask'].append(attention_mask) + for k in outputs: + outputs[k] = torch.tensor(outputs[k]) + return outputs, max_len + + @staticmethod + def _make_batch_key(entry: SubmitEntry) -> tuple: + data = entry.data + return (data['top_k'], data['top_p'], data['temperature']) + + def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]: + entry = q.popleft() + uids = [entry.uid] + batch = [entry.data] + while len(batch) < self.max_batch_size: + if len(q) == 0: + break + if self._make_batch_key(entry) != self._make_batch_key(q[0]): + break + if q[0].data['max_tokens'] > entry.data['max_tokens']: + break + e = q.popleft() + batch.append(e.data) + uids.append(e.uid) + inputs, max_len = self._left_padding(batch) + trunc_lens = [] + for data in batch: + trunc_lens.append(max_len + data['max_tokens']) + inputs['top_k'] = entry.data['top_k'] + inputs['top_p'] = entry.data['top_p'] + inputs['temperature'] = entry.data['temperature'] + inputs['max_tokens'] = max_len + entry.data['max_tokens'] + return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens} + + def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]: + retval = [] + for uid, output, trunc_len in zip(task_entry.uids, task_entry.batch, trunc_lens): + retval.append((uid, output[:trunc_len])) + return retval diff --git a/examples/tutorial/opt/inference/benchmark/locustfile.py b/examples/tutorial/opt/inference/benchmark/locustfile.py new file mode 100644 index 0000000000000000000000000000000000000000..4d829e5d83bf73c45b20d05bfa549b5ee3b869ba --- /dev/null +++ b/examples/tutorial/opt/inference/benchmark/locustfile.py @@ -0,0 +1,15 @@ +from locust import HttpUser, task +from json import JSONDecodeError + + +class GenerationUser(HttpUser): + @task + def generate(self): + prompt = 'Question: What is the longest river on the earth? Answer:' + for i in range(4, 9): + data = {'max_tokens': 2**i, 'prompt': prompt} + with self.client.post('/generation', json=data, catch_response=True) as response: + if response.status_code in (200, 406): + response.success() + else: + response.failure('Response wrong') diff --git a/examples/tutorial/opt/inference/cache.py b/examples/tutorial/opt/inference/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..30febc44fbb3d7afb703c641e582562f90c6b8d0 --- /dev/null +++ b/examples/tutorial/opt/inference/cache.py @@ -0,0 +1,64 @@ +from collections import OrderedDict +from threading import Lock +from contextlib import contextmanager +from typing import List, Any, Hashable, Dict + + +class MissCacheError(Exception): + pass + + +class ListCache: + def __init__(self, cache_size: int, list_size: int, fixed_keys: List[Hashable] = []) -> None: + """Cache a list of values. The fixed keys won't be removed. For other keys, LRU is applied. + When the value list is not full, a cache miss occurs. Otherwise, a cache hit occurs. Redundant values will be removed. + + Args: + cache_size (int): Max size for LRU cache. + list_size (int): Value list size. + fixed_keys (List[Hashable], optional): The keys which won't be removed. Defaults to []. + """ + self.cache_size = cache_size + self.list_size = list_size + self.cache: OrderedDict[Hashable, List[Any]] = OrderedDict() + self.fixed_cache: Dict[Hashable, List[Any]] = {} + for key in fixed_keys: + self.fixed_cache[key] = [] + self._lock = Lock() + + def get(self, key: Hashable) -> List[Any]: + with self.lock(): + if key in self.fixed_cache: + l = self.fixed_cache[key] + if len(l) >= self.list_size: + return l + elif key in self.cache: + self.cache.move_to_end(key) + l = self.cache[key] + if len(l) >= self.list_size: + return l + raise MissCacheError() + + def add(self, key: Hashable, value: Any) -> None: + with self.lock(): + if key in self.fixed_cache: + l = self.fixed_cache[key] + if len(l) < self.list_size and value not in l: + l.append(value) + elif key in self.cache: + self.cache.move_to_end(key) + l = self.cache[key] + if len(l) < self.list_size and value not in l: + l.append(value) + else: + if len(self.cache) >= self.cache_size: + self.cache.popitem(last=False) + self.cache[key] = [value] + + @contextmanager + def lock(self): + try: + self._lock.acquire() + yield + finally: + self._lock.release() diff --git a/examples/tutorial/opt/inference/opt_fastapi.py b/examples/tutorial/opt/inference/opt_fastapi.py new file mode 100644 index 0000000000000000000000000000000000000000..cbfc2a22e7c0c98070d940f3a42cfd66ce839028 --- /dev/null +++ b/examples/tutorial/opt/inference/opt_fastapi.py @@ -0,0 +1,123 @@ +import argparse +import logging +import random +from typing import Optional + +import uvicorn +from energonai import QueueFullError, launch_engine +from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B +from fastapi import FastAPI, HTTPException, Request +from pydantic import BaseModel, Field +from transformers import GPT2Tokenizer + +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError + + +class GenerationTaskReq(BaseModel): + max_tokens: int = Field(gt=0, le=256, example=64) + prompt: str = Field( + min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + top_k: Optional[int] = Field(default=None, gt=0, example=50) + top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) + temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) + + +app = FastAPI() + + +@app.post('/generation') +async def generate(data: GenerationTaskReq, request: Request): + logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}') + key = (data.prompt, data.max_tokens) + try: + if cache is None: + raise MissCacheError() + outputs = cache.get(key) + output = random.choice(outputs) + logger.info('Cache hit') + except MissCacheError: + inputs = tokenizer(data.prompt, truncation=True, max_length=512) + inputs['max_tokens'] = data.max_tokens + inputs['top_k'] = data.top_k + inputs['top_p'] = data.top_p + inputs['temperature'] = data.temperature + try: + uid = id(data) + engine.submit(uid, inputs) + output = await engine.wait(uid) + output = tokenizer.decode(output, skip_special_tokens=True) + if cache is not None: + cache.add(key, output) + except QueueFullError as e: + raise HTTPException(status_code=406, detail=e.args[0]) + + return {'text': output} + + +@app.on_event("shutdown") +async def shutdown(*_): + engine.shutdown() + server.should_exit = True + server.force_exit = True + await server.shutdown() + + +def get_model_fn(model_name: str): + model_map = { + 'opt-125m': opt_125M, + 'opt-6.7b': opt_6B, + 'opt-30b': opt_30B, + 'opt-175b': opt_175B + } + return model_map[model_name] + + +def print_args(args: argparse.Namespace): + print('\n==> Args:') + for k, v in args.__dict__.items(): + print(f'{k} = {v}') + + +FIXED_CACHE_KEYS = [ + ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), + ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), + ("English: I am happy today.\nChinese: ๆˆ‘ไปŠๅคฉๅพˆๅผ€ๅฟƒใ€‚\n\nEnglish: I am going to play basketball.\nChinese: ๆˆ‘ไธ€ไผšๅŽปๆ‰“็ฏฎ็ƒใ€‚\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) +] + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) + parser.add_argument('--tp', type=int, default=1) + parser.add_argument('--master_host', default='localhost') + parser.add_argument('--master_port', type=int, default=19990) + parser.add_argument('--rpc_port', type=int, default=19980) + parser.add_argument('--max_batch_size', type=int, default=8) + parser.add_argument('--pipe_size', type=int, default=1) + parser.add_argument('--queue_size', type=int, default=0) + parser.add_argument('--http_host', default='0.0.0.0') + parser.add_argument('--http_port', type=int, default=7070) + parser.add_argument('--checkpoint', default=None) + parser.add_argument('--cache_size', type=int, default=0) + parser.add_argument('--cache_list_size', type=int, default=1) + args = parser.parse_args() + print_args(args) + model_kwargs = {} + if args.checkpoint is not None: + model_kwargs['checkpoint'] = args.checkpoint + + logger = logging.getLogger(__name__) + tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') + if args.cache_size > 0: + cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) + else: + cache = None + engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), + batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, + pad_token_id=tokenizer.pad_token_id), + pipe_size=args.pipe_size, + queue_size=args.queue_size, + **model_kwargs) + config = uvicorn.Config(app, host=args.http_host, port=args.http_port) + server = uvicorn.Server(config=config) + server.run() diff --git a/examples/tutorial/opt/inference/opt_server.py b/examples/tutorial/opt/inference/opt_server.py new file mode 100644 index 0000000000000000000000000000000000000000..8dab82622c59c313aee836a4509a27fd48ccad1c --- /dev/null +++ b/examples/tutorial/opt/inference/opt_server.py @@ -0,0 +1,122 @@ +import logging +import argparse +import random +from torch import Tensor +from pydantic import BaseModel, Field +from typing import Optional +from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B +from transformers import GPT2Tokenizer +from energonai import launch_engine, QueueFullError +from sanic import Sanic +from sanic.request import Request +from sanic.response import json +from sanic_ext import validate, openapi +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError + + +class GenerationTaskReq(BaseModel): + max_tokens: int = Field(gt=0, le=256, example=64) + prompt: str = Field( + min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + top_k: Optional[int] = Field(default=None, gt=0, example=50) + top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) + temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) + + +app = Sanic('opt') + + +@app.post('/generation') +@openapi.body(GenerationTaskReq) +@validate(json=GenerationTaskReq) +async def generate(request: Request, body: GenerationTaskReq): + logger.info(f'{request.ip}:{request.port} - "{request.method} {request.path}" - {body}') + key = (body.prompt, body.max_tokens) + try: + if cache is None: + raise MissCacheError() + outputs = cache.get(key) + output = random.choice(outputs) + logger.info('Cache hit') + except MissCacheError: + inputs = tokenizer(body.prompt, truncation=True, max_length=512) + inputs['max_tokens'] = body.max_tokens + inputs['top_k'] = body.top_k + inputs['top_p'] = body.top_p + inputs['temperature'] = body.temperature + try: + uid = id(body) + engine.submit(uid, inputs) + output = await engine.wait(uid) + assert isinstance(output, Tensor) + output = tokenizer.decode(output, skip_special_tokens=True) + if cache is not None: + cache.add(key, output) + except QueueFullError as e: + return json({'detail': e.args[0]}, status=406) + + return json({'text': output}) + + +@app.after_server_stop +def shutdown(*_): + engine.shutdown() + + +def get_model_fn(model_name: str): + model_map = { + 'opt-125m': opt_125M, + 'opt-6.7b': opt_6B, + 'opt-30b': opt_30B, + 'opt-175b': opt_175B + } + return model_map[model_name] + + +def print_args(args: argparse.Namespace): + print('\n==> Args:') + for k, v in args.__dict__.items(): + print(f'{k} = {v}') + + +FIXED_CACHE_KEYS = [ + ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), + ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), + ("English: I am happy today.\nChinese: ๆˆ‘ไปŠๅคฉๅพˆๅผ€ๅฟƒใ€‚\n\nEnglish: I am going to play basketball.\nChinese: ๆˆ‘ไธ€ไผšๅŽปๆ‰“็ฏฎ็ƒใ€‚\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) +] + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) + parser.add_argument('--tp', type=int, default=1) + parser.add_argument('--master_host', default='localhost') + parser.add_argument('--master_port', type=int, default=19990) + parser.add_argument('--rpc_port', type=int, default=19980) + parser.add_argument('--max_batch_size', type=int, default=8) + parser.add_argument('--pipe_size', type=int, default=1) + parser.add_argument('--queue_size', type=int, default=0) + parser.add_argument('--http_host', default='0.0.0.0') + parser.add_argument('--http_port', type=int, default=7070) + parser.add_argument('--checkpoint', default=None) + parser.add_argument('--cache_size', type=int, default=0) + parser.add_argument('--cache_list_size', type=int, default=1) + args = parser.parse_args() + print_args(args) + model_kwargs = {} + if args.checkpoint is not None: + model_kwargs['checkpoint'] = args.checkpoint + + logger = logging.getLogger(__name__) + tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') + if args.cache_size > 0: + cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) + else: + cache = None + engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), + batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, + pad_token_id=tokenizer.pad_token_id), + pipe_size=args.pipe_size, + queue_size=args.queue_size, + **model_kwargs) + app.run(args.http_host, args.http_port) diff --git a/examples/tutorial/opt/inference/requirements.txt b/examples/tutorial/opt/inference/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d0970d587f3cae693cbf63f572bd30f83b385ea9 --- /dev/null +++ b/examples/tutorial/opt/inference/requirements.txt @@ -0,0 +1,8 @@ +fastapi==0.85.1 +locust==2.11.0 +pydantic==1.10.2 +sanic==22.9.0 +sanic_ext==22.9.0 +torch>=1.10.0 +transformers==4.23.1 +uvicorn==0.19.0 diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/README.md b/examples/tutorial/opt/inference/script/process-opt-175b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bc3cba72df33c3242ed35625e3467aaefb4849e1 --- /dev/null +++ b/examples/tutorial/opt/inference/script/process-opt-175b/README.md @@ -0,0 +1,46 @@ +# Process OPT-175B weights + +You should download the pre-trained weights following the [doc](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT) before reading this. + +First, install `metaseq` and `git clone https://github.com/facebookresearch/metaseq.git`. + +Then, `cd metaseq`. + +To consolidate checkpoints to eliminate FSDP: + +```shell +bash metaseq/scripts/reshard_mp_launch_no_slurm.sh /checkpoint_last / 8 1 +``` + +You will get 8 files in ``, and you should have the following checksums: +``` +7e71cb65c4be784aa0b2889ac6039ee8 reshard-model_part-0-shard0.pt +c8123da04f2c25a9026ea3224d5d5022 reshard-model_part-1-shard0.pt +45e5d10896382e5bc4a7064fcafd2b1e reshard-model_part-2-shard0.pt +abb7296c4d2fc17420b84ca74fc3ce64 reshard-model_part-3-shard0.pt +05dcc7ac6046f4d3f90b3d1068e6da15 reshard-model_part-4-shard0.pt +d24dd334019060ce1ee7e625fcf6b4bd reshard-model_part-5-shard0.pt +fb1615ce0bbe89cc717f3e5079ee2655 reshard-model_part-6-shard0.pt +2f3124432d2dbc6aebfca06be4b791c2 reshard-model_part-7-shard0.pt +``` + +Copy `flat-meta.json` to ``. + +Then cd to this dir, and we unflatten parameters. + +```shell +bash unflat.sh / / +``` + +Finally, you will get 8 files in `` with following checksums: +``` +6169c59d014be95553c89ec01b8abb62 reshard-model_part-0.pt +58868105da3d74a528a548fdb3a8cff6 reshard-model_part-1.pt +69b255dc5a49d0eba9e4b60432cda90b reshard-model_part-2.pt +002c052461ff9ffb0cdac3d5906f41f2 reshard-model_part-3.pt +6d57f72909320d511ffd5f1c668b2beb reshard-model_part-4.pt +93c8c4041cdc0c7907cc7afcf15cec2a reshard-model_part-5.pt +5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt +f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt +``` + diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py b/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..a17ddd4fa1735fcd1a114f1fd3c87870c257d6bd --- /dev/null +++ b/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py @@ -0,0 +1,55 @@ +import argparse +import json +import os +import re +from collections import defaultdict + +import numpy as np +import torch + + +def load_json(path: str): + with open(path) as f: + return json.load(f) + + +def parse_shape_info(flat_dir: str): + data = load_json(os.path.join(flat_dir, 'shape.json')) + flat_info = defaultdict(lambda: defaultdict(list)) + for k, shape in data.items(): + matched = re.match(r'decoder.layers.\d+', k) + if matched is None: + flat_key = 'flat_param_0' + else: + flat_key = f'{matched[0]}.flat_param_0' + flat_info[flat_key]['names'].append(k) + flat_info[flat_key]['shapes'].append(shape) + flat_info[flat_key]['numels'].append(int(np.prod(shape))) + return flat_info + + +def convert(flat_dir: str, output_dir: str, part: int): + flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt') + output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt') + flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json')) + flat_sd = torch.load(flat_path) + print(f'Loaded flat state dict from {flat_path}') + output_sd = {} + for flat_key, param_meta in flat_meta.items(): + flat_param = flat_sd['model'][flat_key] + assert sum(param_meta['numels']) == flat_param.numel( + ), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}' + for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])): + output_sd[name] = param.view(shape) + + torch.save(output_sd, output_path) + print(f'Saved unflat state dict to {output_path}') + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('flat_dir') + parser.add_argument('output_dir') + parser.add_argument('part', type=int) + args = parser.parse_args() + convert(args.flat_dir, args.output_dir, args.part) diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json b/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json new file mode 100644 index 0000000000000000000000000000000000000000..59d285565cfdb8d68ce0e5dc91aab2046c136e9c --- /dev/null +++ b/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json @@ -0,0 +1 @@ +{"flat_param_0": {"names": ["decoder.embed_tokens.weight", "decoder.embed_positions.weight", "decoder.layer_norm.weight", "decoder.layer_norm.bias"], "shapes": [[6284, 12288], [2050, 12288], [12288], [12288]], "numels": [77217792, 25190400, 12288, 12288]}, "decoder.layers.0.flat_param_0": {"names": ["decoder.layers.0.self_attn.qkv_proj.weight", "decoder.layers.0.self_attn.qkv_proj.bias", "decoder.layers.0.self_attn.out_proj.weight", "decoder.layers.0.self_attn.out_proj.bias", "decoder.layers.0.self_attn_layer_norm.weight", "decoder.layers.0.self_attn_layer_norm.bias", "decoder.layers.0.fc1.weight", "decoder.layers.0.fc1.bias", "decoder.layers.0.fc2.weight", "decoder.layers.0.fc2.bias", "decoder.layers.0.final_layer_norm.weight", "decoder.layers.0.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.1.flat_param_0": {"names": ["decoder.layers.1.self_attn.qkv_proj.weight", "decoder.layers.1.self_attn.qkv_proj.bias", "decoder.layers.1.self_attn.out_proj.weight", "decoder.layers.1.self_attn.out_proj.bias", "decoder.layers.1.self_attn_layer_norm.weight", "decoder.layers.1.self_attn_layer_norm.bias", "decoder.layers.1.fc1.weight", "decoder.layers.1.fc1.bias", "decoder.layers.1.fc2.weight", "decoder.layers.1.fc2.bias", "decoder.layers.1.final_layer_norm.weight", "decoder.layers.1.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.2.flat_param_0": {"names": ["decoder.layers.2.self_attn.qkv_proj.weight", "decoder.layers.2.self_attn.qkv_proj.bias", "decoder.layers.2.self_attn.out_proj.weight", "decoder.layers.2.self_attn.out_proj.bias", "decoder.layers.2.self_attn_layer_norm.weight", "decoder.layers.2.self_attn_layer_norm.bias", "decoder.layers.2.fc1.weight", "decoder.layers.2.fc1.bias", "decoder.layers.2.fc2.weight", "decoder.layers.2.fc2.bias", "decoder.layers.2.final_layer_norm.weight", "decoder.layers.2.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.3.flat_param_0": {"names": ["decoder.layers.3.self_attn.qkv_proj.weight", "decoder.layers.3.self_attn.qkv_proj.bias", "decoder.layers.3.self_attn.out_proj.weight", "decoder.layers.3.self_attn.out_proj.bias", "decoder.layers.3.self_attn_layer_norm.weight", "decoder.layers.3.self_attn_layer_norm.bias", "decoder.layers.3.fc1.weight", "decoder.layers.3.fc1.bias", "decoder.layers.3.fc2.weight", "decoder.layers.3.fc2.bias", "decoder.layers.3.final_layer_norm.weight", "decoder.layers.3.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.4.flat_param_0": {"names": ["decoder.layers.4.self_attn.qkv_proj.weight", "decoder.layers.4.self_attn.qkv_proj.bias", "decoder.layers.4.self_attn.out_proj.weight", "decoder.layers.4.self_attn.out_proj.bias", "decoder.layers.4.self_attn_layer_norm.weight", "decoder.layers.4.self_attn_layer_norm.bias", "decoder.layers.4.fc1.weight", "decoder.layers.4.fc1.bias", "decoder.layers.4.fc2.weight", "decoder.layers.4.fc2.bias", "decoder.layers.4.final_layer_norm.weight", "decoder.layers.4.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.5.flat_param_0": {"names": ["decoder.layers.5.self_attn.qkv_proj.weight", "decoder.layers.5.self_attn.qkv_proj.bias", "decoder.layers.5.self_attn.out_proj.weight", "decoder.layers.5.self_attn.out_proj.bias", "decoder.layers.5.self_attn_layer_norm.weight", "decoder.layers.5.self_attn_layer_norm.bias", "decoder.layers.5.fc1.weight", "decoder.layers.5.fc1.bias", "decoder.layers.5.fc2.weight", "decoder.layers.5.fc2.bias", "decoder.layers.5.final_layer_norm.weight", "decoder.layers.5.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.6.flat_param_0": {"names": ["decoder.layers.6.self_attn.qkv_proj.weight", "decoder.layers.6.self_attn.qkv_proj.bias", "decoder.layers.6.self_attn.out_proj.weight", "decoder.layers.6.self_attn.out_proj.bias", "decoder.layers.6.self_attn_layer_norm.weight", "decoder.layers.6.self_attn_layer_norm.bias", "decoder.layers.6.fc1.weight", "decoder.layers.6.fc1.bias", "decoder.layers.6.fc2.weight", "decoder.layers.6.fc2.bias", "decoder.layers.6.final_layer_norm.weight", "decoder.layers.6.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.7.flat_param_0": {"names": ["decoder.layers.7.self_attn.qkv_proj.weight", "decoder.layers.7.self_attn.qkv_proj.bias", "decoder.layers.7.self_attn.out_proj.weight", "decoder.layers.7.self_attn.out_proj.bias", "decoder.layers.7.self_attn_layer_norm.weight", "decoder.layers.7.self_attn_layer_norm.bias", "decoder.layers.7.fc1.weight", "decoder.layers.7.fc1.bias", "decoder.layers.7.fc2.weight", "decoder.layers.7.fc2.bias", "decoder.layers.7.final_layer_norm.weight", "decoder.layers.7.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.8.flat_param_0": {"names": ["decoder.layers.8.self_attn.qkv_proj.weight", "decoder.layers.8.self_attn.qkv_proj.bias", "decoder.layers.8.self_attn.out_proj.weight", "decoder.layers.8.self_attn.out_proj.bias", "decoder.layers.8.self_attn_layer_norm.weight", "decoder.layers.8.self_attn_layer_norm.bias", "decoder.layers.8.fc1.weight", "decoder.layers.8.fc1.bias", "decoder.layers.8.fc2.weight", "decoder.layers.8.fc2.bias", "decoder.layers.8.final_layer_norm.weight", "decoder.layers.8.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.9.flat_param_0": {"names": ["decoder.layers.9.self_attn.qkv_proj.weight", "decoder.layers.9.self_attn.qkv_proj.bias", "decoder.layers.9.self_attn.out_proj.weight", "decoder.layers.9.self_attn.out_proj.bias", "decoder.layers.9.self_attn_layer_norm.weight", "decoder.layers.9.self_attn_layer_norm.bias", "decoder.layers.9.fc1.weight", "decoder.layers.9.fc1.bias", "decoder.layers.9.fc2.weight", "decoder.layers.9.fc2.bias", "decoder.layers.9.final_layer_norm.weight", "decoder.layers.9.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.10.flat_param_0": {"names": ["decoder.layers.10.self_attn.qkv_proj.weight", "decoder.layers.10.self_attn.qkv_proj.bias", "decoder.layers.10.self_attn.out_proj.weight", "decoder.layers.10.self_attn.out_proj.bias", "decoder.layers.10.self_attn_layer_norm.weight", "decoder.layers.10.self_attn_layer_norm.bias", "decoder.layers.10.fc1.weight", "decoder.layers.10.fc1.bias", "decoder.layers.10.fc2.weight", "decoder.layers.10.fc2.bias", "decoder.layers.10.final_layer_norm.weight", "decoder.layers.10.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.11.flat_param_0": {"names": ["decoder.layers.11.self_attn.qkv_proj.weight", "decoder.layers.11.self_attn.qkv_proj.bias", "decoder.layers.11.self_attn.out_proj.weight", "decoder.layers.11.self_attn.out_proj.bias", "decoder.layers.11.self_attn_layer_norm.weight", "decoder.layers.11.self_attn_layer_norm.bias", "decoder.layers.11.fc1.weight", "decoder.layers.11.fc1.bias", "decoder.layers.11.fc2.weight", "decoder.layers.11.fc2.bias", "decoder.layers.11.final_layer_norm.weight", "decoder.layers.11.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.12.flat_param_0": {"names": ["decoder.layers.12.self_attn.qkv_proj.weight", "decoder.layers.12.self_attn.qkv_proj.bias", "decoder.layers.12.self_attn.out_proj.weight", "decoder.layers.12.self_attn.out_proj.bias", "decoder.layers.12.self_attn_layer_norm.weight", "decoder.layers.12.self_attn_layer_norm.bias", "decoder.layers.12.fc1.weight", "decoder.layers.12.fc1.bias", "decoder.layers.12.fc2.weight", "decoder.layers.12.fc2.bias", "decoder.layers.12.final_layer_norm.weight", "decoder.layers.12.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.13.flat_param_0": {"names": ["decoder.layers.13.self_attn.qkv_proj.weight", "decoder.layers.13.self_attn.qkv_proj.bias", "decoder.layers.13.self_attn.out_proj.weight", "decoder.layers.13.self_attn.out_proj.bias", "decoder.layers.13.self_attn_layer_norm.weight", "decoder.layers.13.self_attn_layer_norm.bias", "decoder.layers.13.fc1.weight", "decoder.layers.13.fc1.bias", "decoder.layers.13.fc2.weight", "decoder.layers.13.fc2.bias", "decoder.layers.13.final_layer_norm.weight", "decoder.layers.13.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.14.flat_param_0": {"names": ["decoder.layers.14.self_attn.qkv_proj.weight", "decoder.layers.14.self_attn.qkv_proj.bias", "decoder.layers.14.self_attn.out_proj.weight", "decoder.layers.14.self_attn.out_proj.bias", "decoder.layers.14.self_attn_layer_norm.weight", "decoder.layers.14.self_attn_layer_norm.bias", "decoder.layers.14.fc1.weight", "decoder.layers.14.fc1.bias", "decoder.layers.14.fc2.weight", "decoder.layers.14.fc2.bias", "decoder.layers.14.final_layer_norm.weight", "decoder.layers.14.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.15.flat_param_0": {"names": ["decoder.layers.15.self_attn.qkv_proj.weight", "decoder.layers.15.self_attn.qkv_proj.bias", "decoder.layers.15.self_attn.out_proj.weight", "decoder.layers.15.self_attn.out_proj.bias", "decoder.layers.15.self_attn_layer_norm.weight", "decoder.layers.15.self_attn_layer_norm.bias", "decoder.layers.15.fc1.weight", "decoder.layers.15.fc1.bias", "decoder.layers.15.fc2.weight", "decoder.layers.15.fc2.bias", "decoder.layers.15.final_layer_norm.weight", "decoder.layers.15.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.16.flat_param_0": {"names": ["decoder.layers.16.self_attn.qkv_proj.weight", "decoder.layers.16.self_attn.qkv_proj.bias", "decoder.layers.16.self_attn.out_proj.weight", "decoder.layers.16.self_attn.out_proj.bias", "decoder.layers.16.self_attn_layer_norm.weight", "decoder.layers.16.self_attn_layer_norm.bias", "decoder.layers.16.fc1.weight", "decoder.layers.16.fc1.bias", "decoder.layers.16.fc2.weight", "decoder.layers.16.fc2.bias", "decoder.layers.16.final_layer_norm.weight", "decoder.layers.16.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.17.flat_param_0": {"names": ["decoder.layers.17.self_attn.qkv_proj.weight", "decoder.layers.17.self_attn.qkv_proj.bias", "decoder.layers.17.self_attn.out_proj.weight", "decoder.layers.17.self_attn.out_proj.bias", "decoder.layers.17.self_attn_layer_norm.weight", "decoder.layers.17.self_attn_layer_norm.bias", "decoder.layers.17.fc1.weight", "decoder.layers.17.fc1.bias", "decoder.layers.17.fc2.weight", "decoder.layers.17.fc2.bias", "decoder.layers.17.final_layer_norm.weight", "decoder.layers.17.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.18.flat_param_0": {"names": ["decoder.layers.18.self_attn.qkv_proj.weight", "decoder.layers.18.self_attn.qkv_proj.bias", "decoder.layers.18.self_attn.out_proj.weight", "decoder.layers.18.self_attn.out_proj.bias", "decoder.layers.18.self_attn_layer_norm.weight", "decoder.layers.18.self_attn_layer_norm.bias", "decoder.layers.18.fc1.weight", "decoder.layers.18.fc1.bias", "decoder.layers.18.fc2.weight", "decoder.layers.18.fc2.bias", "decoder.layers.18.final_layer_norm.weight", "decoder.layers.18.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.19.flat_param_0": {"names": ["decoder.layers.19.self_attn.qkv_proj.weight", "decoder.layers.19.self_attn.qkv_proj.bias", "decoder.layers.19.self_attn.out_proj.weight", "decoder.layers.19.self_attn.out_proj.bias", "decoder.layers.19.self_attn_layer_norm.weight", "decoder.layers.19.self_attn_layer_norm.bias", "decoder.layers.19.fc1.weight", "decoder.layers.19.fc1.bias", "decoder.layers.19.fc2.weight", "decoder.layers.19.fc2.bias", "decoder.layers.19.final_layer_norm.weight", "decoder.layers.19.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.20.flat_param_0": {"names": ["decoder.layers.20.self_attn.qkv_proj.weight", "decoder.layers.20.self_attn.qkv_proj.bias", "decoder.layers.20.self_attn.out_proj.weight", "decoder.layers.20.self_attn.out_proj.bias", "decoder.layers.20.self_attn_layer_norm.weight", "decoder.layers.20.self_attn_layer_norm.bias", "decoder.layers.20.fc1.weight", "decoder.layers.20.fc1.bias", "decoder.layers.20.fc2.weight", "decoder.layers.20.fc2.bias", "decoder.layers.20.final_layer_norm.weight", "decoder.layers.20.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.21.flat_param_0": {"names": ["decoder.layers.21.self_attn.qkv_proj.weight", "decoder.layers.21.self_attn.qkv_proj.bias", "decoder.layers.21.self_attn.out_proj.weight", "decoder.layers.21.self_attn.out_proj.bias", "decoder.layers.21.self_attn_layer_norm.weight", "decoder.layers.21.self_attn_layer_norm.bias", "decoder.layers.21.fc1.weight", "decoder.layers.21.fc1.bias", "decoder.layers.21.fc2.weight", "decoder.layers.21.fc2.bias", "decoder.layers.21.final_layer_norm.weight", "decoder.layers.21.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.22.flat_param_0": {"names": ["decoder.layers.22.self_attn.qkv_proj.weight", "decoder.layers.22.self_attn.qkv_proj.bias", "decoder.layers.22.self_attn.out_proj.weight", "decoder.layers.22.self_attn.out_proj.bias", "decoder.layers.22.self_attn_layer_norm.weight", "decoder.layers.22.self_attn_layer_norm.bias", "decoder.layers.22.fc1.weight", "decoder.layers.22.fc1.bias", "decoder.layers.22.fc2.weight", "decoder.layers.22.fc2.bias", "decoder.layers.22.final_layer_norm.weight", "decoder.layers.22.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.23.flat_param_0": {"names": ["decoder.layers.23.self_attn.qkv_proj.weight", "decoder.layers.23.self_attn.qkv_proj.bias", "decoder.layers.23.self_attn.out_proj.weight", "decoder.layers.23.self_attn.out_proj.bias", "decoder.layers.23.self_attn_layer_norm.weight", "decoder.layers.23.self_attn_layer_norm.bias", "decoder.layers.23.fc1.weight", "decoder.layers.23.fc1.bias", "decoder.layers.23.fc2.weight", "decoder.layers.23.fc2.bias", "decoder.layers.23.final_layer_norm.weight", "decoder.layers.23.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.24.flat_param_0": {"names": ["decoder.layers.24.self_attn.qkv_proj.weight", "decoder.layers.24.self_attn.qkv_proj.bias", "decoder.layers.24.self_attn.out_proj.weight", "decoder.layers.24.self_attn.out_proj.bias", "decoder.layers.24.self_attn_layer_norm.weight", "decoder.layers.24.self_attn_layer_norm.bias", "decoder.layers.24.fc1.weight", "decoder.layers.24.fc1.bias", "decoder.layers.24.fc2.weight", "decoder.layers.24.fc2.bias", "decoder.layers.24.final_layer_norm.weight", "decoder.layers.24.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.25.flat_param_0": {"names": ["decoder.layers.25.self_attn.qkv_proj.weight", "decoder.layers.25.self_attn.qkv_proj.bias", "decoder.layers.25.self_attn.out_proj.weight", "decoder.layers.25.self_attn.out_proj.bias", "decoder.layers.25.self_attn_layer_norm.weight", "decoder.layers.25.self_attn_layer_norm.bias", "decoder.layers.25.fc1.weight", "decoder.layers.25.fc1.bias", "decoder.layers.25.fc2.weight", "decoder.layers.25.fc2.bias", "decoder.layers.25.final_layer_norm.weight", "decoder.layers.25.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.26.flat_param_0": {"names": ["decoder.layers.26.self_attn.qkv_proj.weight", "decoder.layers.26.self_attn.qkv_proj.bias", "decoder.layers.26.self_attn.out_proj.weight", "decoder.layers.26.self_attn.out_proj.bias", "decoder.layers.26.self_attn_layer_norm.weight", "decoder.layers.26.self_attn_layer_norm.bias", "decoder.layers.26.fc1.weight", "decoder.layers.26.fc1.bias", "decoder.layers.26.fc2.weight", "decoder.layers.26.fc2.bias", "decoder.layers.26.final_layer_norm.weight", "decoder.layers.26.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.27.flat_param_0": {"names": ["decoder.layers.27.self_attn.qkv_proj.weight", "decoder.layers.27.self_attn.qkv_proj.bias", "decoder.layers.27.self_attn.out_proj.weight", "decoder.layers.27.self_attn.out_proj.bias", "decoder.layers.27.self_attn_layer_norm.weight", "decoder.layers.27.self_attn_layer_norm.bias", "decoder.layers.27.fc1.weight", "decoder.layers.27.fc1.bias", "decoder.layers.27.fc2.weight", "decoder.layers.27.fc2.bias", "decoder.layers.27.final_layer_norm.weight", "decoder.layers.27.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.28.flat_param_0": {"names": ["decoder.layers.28.self_attn.qkv_proj.weight", "decoder.layers.28.self_attn.qkv_proj.bias", "decoder.layers.28.self_attn.out_proj.weight", "decoder.layers.28.self_attn.out_proj.bias", "decoder.layers.28.self_attn_layer_norm.weight", "decoder.layers.28.self_attn_layer_norm.bias", "decoder.layers.28.fc1.weight", "decoder.layers.28.fc1.bias", "decoder.layers.28.fc2.weight", "decoder.layers.28.fc2.bias", "decoder.layers.28.final_layer_norm.weight", "decoder.layers.28.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.29.flat_param_0": {"names": ["decoder.layers.29.self_attn.qkv_proj.weight", "decoder.layers.29.self_attn.qkv_proj.bias", "decoder.layers.29.self_attn.out_proj.weight", "decoder.layers.29.self_attn.out_proj.bias", "decoder.layers.29.self_attn_layer_norm.weight", "decoder.layers.29.self_attn_layer_norm.bias", "decoder.layers.29.fc1.weight", "decoder.layers.29.fc1.bias", "decoder.layers.29.fc2.weight", "decoder.layers.29.fc2.bias", "decoder.layers.29.final_layer_norm.weight", "decoder.layers.29.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.30.flat_param_0": {"names": ["decoder.layers.30.self_attn.qkv_proj.weight", "decoder.layers.30.self_attn.qkv_proj.bias", "decoder.layers.30.self_attn.out_proj.weight", "decoder.layers.30.self_attn.out_proj.bias", "decoder.layers.30.self_attn_layer_norm.weight", "decoder.layers.30.self_attn_layer_norm.bias", "decoder.layers.30.fc1.weight", "decoder.layers.30.fc1.bias", "decoder.layers.30.fc2.weight", "decoder.layers.30.fc2.bias", "decoder.layers.30.final_layer_norm.weight", "decoder.layers.30.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.31.flat_param_0": {"names": ["decoder.layers.31.self_attn.qkv_proj.weight", "decoder.layers.31.self_attn.qkv_proj.bias", "decoder.layers.31.self_attn.out_proj.weight", "decoder.layers.31.self_attn.out_proj.bias", "decoder.layers.31.self_attn_layer_norm.weight", "decoder.layers.31.self_attn_layer_norm.bias", "decoder.layers.31.fc1.weight", "decoder.layers.31.fc1.bias", "decoder.layers.31.fc2.weight", "decoder.layers.31.fc2.bias", "decoder.layers.31.final_layer_norm.weight", "decoder.layers.31.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.32.flat_param_0": {"names": ["decoder.layers.32.self_attn.qkv_proj.weight", "decoder.layers.32.self_attn.qkv_proj.bias", "decoder.layers.32.self_attn.out_proj.weight", "decoder.layers.32.self_attn.out_proj.bias", "decoder.layers.32.self_attn_layer_norm.weight", "decoder.layers.32.self_attn_layer_norm.bias", "decoder.layers.32.fc1.weight", "decoder.layers.32.fc1.bias", "decoder.layers.32.fc2.weight", "decoder.layers.32.fc2.bias", "decoder.layers.32.final_layer_norm.weight", "decoder.layers.32.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.33.flat_param_0": {"names": ["decoder.layers.33.self_attn.qkv_proj.weight", "decoder.layers.33.self_attn.qkv_proj.bias", "decoder.layers.33.self_attn.out_proj.weight", "decoder.layers.33.self_attn.out_proj.bias", "decoder.layers.33.self_attn_layer_norm.weight", "decoder.layers.33.self_attn_layer_norm.bias", "decoder.layers.33.fc1.weight", "decoder.layers.33.fc1.bias", "decoder.layers.33.fc2.weight", "decoder.layers.33.fc2.bias", "decoder.layers.33.final_layer_norm.weight", "decoder.layers.33.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.34.flat_param_0": {"names": ["decoder.layers.34.self_attn.qkv_proj.weight", "decoder.layers.34.self_attn.qkv_proj.bias", "decoder.layers.34.self_attn.out_proj.weight", "decoder.layers.34.self_attn.out_proj.bias", "decoder.layers.34.self_attn_layer_norm.weight", "decoder.layers.34.self_attn_layer_norm.bias", "decoder.layers.34.fc1.weight", "decoder.layers.34.fc1.bias", "decoder.layers.34.fc2.weight", "decoder.layers.34.fc2.bias", "decoder.layers.34.final_layer_norm.weight", "decoder.layers.34.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.35.flat_param_0": {"names": ["decoder.layers.35.self_attn.qkv_proj.weight", "decoder.layers.35.self_attn.qkv_proj.bias", "decoder.layers.35.self_attn.out_proj.weight", "decoder.layers.35.self_attn.out_proj.bias", "decoder.layers.35.self_attn_layer_norm.weight", "decoder.layers.35.self_attn_layer_norm.bias", "decoder.layers.35.fc1.weight", "decoder.layers.35.fc1.bias", "decoder.layers.35.fc2.weight", "decoder.layers.35.fc2.bias", "decoder.layers.35.final_layer_norm.weight", "decoder.layers.35.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.36.flat_param_0": {"names": ["decoder.layers.36.self_attn.qkv_proj.weight", "decoder.layers.36.self_attn.qkv_proj.bias", "decoder.layers.36.self_attn.out_proj.weight", "decoder.layers.36.self_attn.out_proj.bias", "decoder.layers.36.self_attn_layer_norm.weight", "decoder.layers.36.self_attn_layer_norm.bias", "decoder.layers.36.fc1.weight", "decoder.layers.36.fc1.bias", "decoder.layers.36.fc2.weight", "decoder.layers.36.fc2.bias", "decoder.layers.36.final_layer_norm.weight", "decoder.layers.36.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.37.flat_param_0": {"names": ["decoder.layers.37.self_attn.qkv_proj.weight", "decoder.layers.37.self_attn.qkv_proj.bias", "decoder.layers.37.self_attn.out_proj.weight", "decoder.layers.37.self_attn.out_proj.bias", "decoder.layers.37.self_attn_layer_norm.weight", "decoder.layers.37.self_attn_layer_norm.bias", "decoder.layers.37.fc1.weight", "decoder.layers.37.fc1.bias", "decoder.layers.37.fc2.weight", "decoder.layers.37.fc2.bias", "decoder.layers.37.final_layer_norm.weight", "decoder.layers.37.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.38.flat_param_0": {"names": ["decoder.layers.38.self_attn.qkv_proj.weight", "decoder.layers.38.self_attn.qkv_proj.bias", "decoder.layers.38.self_attn.out_proj.weight", "decoder.layers.38.self_attn.out_proj.bias", "decoder.layers.38.self_attn_layer_norm.weight", "decoder.layers.38.self_attn_layer_norm.bias", "decoder.layers.38.fc1.weight", "decoder.layers.38.fc1.bias", "decoder.layers.38.fc2.weight", "decoder.layers.38.fc2.bias", "decoder.layers.38.final_layer_norm.weight", "decoder.layers.38.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.39.flat_param_0": {"names": ["decoder.layers.39.self_attn.qkv_proj.weight", "decoder.layers.39.self_attn.qkv_proj.bias", "decoder.layers.39.self_attn.out_proj.weight", "decoder.layers.39.self_attn.out_proj.bias", "decoder.layers.39.self_attn_layer_norm.weight", "decoder.layers.39.self_attn_layer_norm.bias", "decoder.layers.39.fc1.weight", "decoder.layers.39.fc1.bias", "decoder.layers.39.fc2.weight", "decoder.layers.39.fc2.bias", "decoder.layers.39.final_layer_norm.weight", "decoder.layers.39.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.40.flat_param_0": {"names": ["decoder.layers.40.self_attn.qkv_proj.weight", "decoder.layers.40.self_attn.qkv_proj.bias", "decoder.layers.40.self_attn.out_proj.weight", "decoder.layers.40.self_attn.out_proj.bias", "decoder.layers.40.self_attn_layer_norm.weight", "decoder.layers.40.self_attn_layer_norm.bias", "decoder.layers.40.fc1.weight", "decoder.layers.40.fc1.bias", "decoder.layers.40.fc2.weight", "decoder.layers.40.fc2.bias", "decoder.layers.40.final_layer_norm.weight", "decoder.layers.40.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.41.flat_param_0": {"names": ["decoder.layers.41.self_attn.qkv_proj.weight", "decoder.layers.41.self_attn.qkv_proj.bias", "decoder.layers.41.self_attn.out_proj.weight", "decoder.layers.41.self_attn.out_proj.bias", "decoder.layers.41.self_attn_layer_norm.weight", "decoder.layers.41.self_attn_layer_norm.bias", "decoder.layers.41.fc1.weight", "decoder.layers.41.fc1.bias", "decoder.layers.41.fc2.weight", "decoder.layers.41.fc2.bias", "decoder.layers.41.final_layer_norm.weight", "decoder.layers.41.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.42.flat_param_0": {"names": ["decoder.layers.42.self_attn.qkv_proj.weight", "decoder.layers.42.self_attn.qkv_proj.bias", "decoder.layers.42.self_attn.out_proj.weight", "decoder.layers.42.self_attn.out_proj.bias", "decoder.layers.42.self_attn_layer_norm.weight", "decoder.layers.42.self_attn_layer_norm.bias", "decoder.layers.42.fc1.weight", "decoder.layers.42.fc1.bias", "decoder.layers.42.fc2.weight", "decoder.layers.42.fc2.bias", "decoder.layers.42.final_layer_norm.weight", "decoder.layers.42.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.43.flat_param_0": {"names": ["decoder.layers.43.self_attn.qkv_proj.weight", "decoder.layers.43.self_attn.qkv_proj.bias", "decoder.layers.43.self_attn.out_proj.weight", "decoder.layers.43.self_attn.out_proj.bias", "decoder.layers.43.self_attn_layer_norm.weight", "decoder.layers.43.self_attn_layer_norm.bias", "decoder.layers.43.fc1.weight", "decoder.layers.43.fc1.bias", "decoder.layers.43.fc2.weight", "decoder.layers.43.fc2.bias", "decoder.layers.43.final_layer_norm.weight", "decoder.layers.43.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.44.flat_param_0": {"names": ["decoder.layers.44.self_attn.qkv_proj.weight", "decoder.layers.44.self_attn.qkv_proj.bias", "decoder.layers.44.self_attn.out_proj.weight", "decoder.layers.44.self_attn.out_proj.bias", "decoder.layers.44.self_attn_layer_norm.weight", "decoder.layers.44.self_attn_layer_norm.bias", "decoder.layers.44.fc1.weight", "decoder.layers.44.fc1.bias", "decoder.layers.44.fc2.weight", "decoder.layers.44.fc2.bias", "decoder.layers.44.final_layer_norm.weight", "decoder.layers.44.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.45.flat_param_0": {"names": ["decoder.layers.45.self_attn.qkv_proj.weight", "decoder.layers.45.self_attn.qkv_proj.bias", "decoder.layers.45.self_attn.out_proj.weight", "decoder.layers.45.self_attn.out_proj.bias", "decoder.layers.45.self_attn_layer_norm.weight", "decoder.layers.45.self_attn_layer_norm.bias", "decoder.layers.45.fc1.weight", "decoder.layers.45.fc1.bias", "decoder.layers.45.fc2.weight", "decoder.layers.45.fc2.bias", "decoder.layers.45.final_layer_norm.weight", "decoder.layers.45.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.46.flat_param_0": {"names": ["decoder.layers.46.self_attn.qkv_proj.weight", "decoder.layers.46.self_attn.qkv_proj.bias", "decoder.layers.46.self_attn.out_proj.weight", "decoder.layers.46.self_attn.out_proj.bias", "decoder.layers.46.self_attn_layer_norm.weight", "decoder.layers.46.self_attn_layer_norm.bias", "decoder.layers.46.fc1.weight", "decoder.layers.46.fc1.bias", "decoder.layers.46.fc2.weight", "decoder.layers.46.fc2.bias", "decoder.layers.46.final_layer_norm.weight", "decoder.layers.46.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.47.flat_param_0": {"names": ["decoder.layers.47.self_attn.qkv_proj.weight", "decoder.layers.47.self_attn.qkv_proj.bias", "decoder.layers.47.self_attn.out_proj.weight", "decoder.layers.47.self_attn.out_proj.bias", "decoder.layers.47.self_attn_layer_norm.weight", "decoder.layers.47.self_attn_layer_norm.bias", "decoder.layers.47.fc1.weight", "decoder.layers.47.fc1.bias", "decoder.layers.47.fc2.weight", "decoder.layers.47.fc2.bias", "decoder.layers.47.final_layer_norm.weight", "decoder.layers.47.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.48.flat_param_0": {"names": ["decoder.layers.48.self_attn.qkv_proj.weight", "decoder.layers.48.self_attn.qkv_proj.bias", "decoder.layers.48.self_attn.out_proj.weight", "decoder.layers.48.self_attn.out_proj.bias", "decoder.layers.48.self_attn_layer_norm.weight", "decoder.layers.48.self_attn_layer_norm.bias", "decoder.layers.48.fc1.weight", "decoder.layers.48.fc1.bias", "decoder.layers.48.fc2.weight", "decoder.layers.48.fc2.bias", "decoder.layers.48.final_layer_norm.weight", "decoder.layers.48.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.49.flat_param_0": {"names": ["decoder.layers.49.self_attn.qkv_proj.weight", "decoder.layers.49.self_attn.qkv_proj.bias", "decoder.layers.49.self_attn.out_proj.weight", "decoder.layers.49.self_attn.out_proj.bias", "decoder.layers.49.self_attn_layer_norm.weight", "decoder.layers.49.self_attn_layer_norm.bias", "decoder.layers.49.fc1.weight", "decoder.layers.49.fc1.bias", "decoder.layers.49.fc2.weight", "decoder.layers.49.fc2.bias", "decoder.layers.49.final_layer_norm.weight", "decoder.layers.49.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.50.flat_param_0": {"names": ["decoder.layers.50.self_attn.qkv_proj.weight", "decoder.layers.50.self_attn.qkv_proj.bias", "decoder.layers.50.self_attn.out_proj.weight", "decoder.layers.50.self_attn.out_proj.bias", "decoder.layers.50.self_attn_layer_norm.weight", "decoder.layers.50.self_attn_layer_norm.bias", "decoder.layers.50.fc1.weight", "decoder.layers.50.fc1.bias", "decoder.layers.50.fc2.weight", "decoder.layers.50.fc2.bias", "decoder.layers.50.final_layer_norm.weight", "decoder.layers.50.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.51.flat_param_0": {"names": ["decoder.layers.51.self_attn.qkv_proj.weight", "decoder.layers.51.self_attn.qkv_proj.bias", "decoder.layers.51.self_attn.out_proj.weight", "decoder.layers.51.self_attn.out_proj.bias", "decoder.layers.51.self_attn_layer_norm.weight", "decoder.layers.51.self_attn_layer_norm.bias", "decoder.layers.51.fc1.weight", "decoder.layers.51.fc1.bias", "decoder.layers.51.fc2.weight", "decoder.layers.51.fc2.bias", "decoder.layers.51.final_layer_norm.weight", "decoder.layers.51.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.52.flat_param_0": {"names": ["decoder.layers.52.self_attn.qkv_proj.weight", "decoder.layers.52.self_attn.qkv_proj.bias", "decoder.layers.52.self_attn.out_proj.weight", "decoder.layers.52.self_attn.out_proj.bias", "decoder.layers.52.self_attn_layer_norm.weight", "decoder.layers.52.self_attn_layer_norm.bias", "decoder.layers.52.fc1.weight", "decoder.layers.52.fc1.bias", "decoder.layers.52.fc2.weight", "decoder.layers.52.fc2.bias", "decoder.layers.52.final_layer_norm.weight", "decoder.layers.52.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.53.flat_param_0": {"names": ["decoder.layers.53.self_attn.qkv_proj.weight", "decoder.layers.53.self_attn.qkv_proj.bias", "decoder.layers.53.self_attn.out_proj.weight", "decoder.layers.53.self_attn.out_proj.bias", "decoder.layers.53.self_attn_layer_norm.weight", "decoder.layers.53.self_attn_layer_norm.bias", "decoder.layers.53.fc1.weight", "decoder.layers.53.fc1.bias", "decoder.layers.53.fc2.weight", "decoder.layers.53.fc2.bias", "decoder.layers.53.final_layer_norm.weight", "decoder.layers.53.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.54.flat_param_0": {"names": ["decoder.layers.54.self_attn.qkv_proj.weight", "decoder.layers.54.self_attn.qkv_proj.bias", "decoder.layers.54.self_attn.out_proj.weight", "decoder.layers.54.self_attn.out_proj.bias", "decoder.layers.54.self_attn_layer_norm.weight", "decoder.layers.54.self_attn_layer_norm.bias", "decoder.layers.54.fc1.weight", "decoder.layers.54.fc1.bias", "decoder.layers.54.fc2.weight", "decoder.layers.54.fc2.bias", "decoder.layers.54.final_layer_norm.weight", "decoder.layers.54.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.55.flat_param_0": {"names": ["decoder.layers.55.self_attn.qkv_proj.weight", "decoder.layers.55.self_attn.qkv_proj.bias", "decoder.layers.55.self_attn.out_proj.weight", "decoder.layers.55.self_attn.out_proj.bias", "decoder.layers.55.self_attn_layer_norm.weight", "decoder.layers.55.self_attn_layer_norm.bias", "decoder.layers.55.fc1.weight", "decoder.layers.55.fc1.bias", "decoder.layers.55.fc2.weight", "decoder.layers.55.fc2.bias", "decoder.layers.55.final_layer_norm.weight", "decoder.layers.55.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.56.flat_param_0": {"names": ["decoder.layers.56.self_attn.qkv_proj.weight", "decoder.layers.56.self_attn.qkv_proj.bias", "decoder.layers.56.self_attn.out_proj.weight", "decoder.layers.56.self_attn.out_proj.bias", "decoder.layers.56.self_attn_layer_norm.weight", "decoder.layers.56.self_attn_layer_norm.bias", "decoder.layers.56.fc1.weight", "decoder.layers.56.fc1.bias", "decoder.layers.56.fc2.weight", "decoder.layers.56.fc2.bias", "decoder.layers.56.final_layer_norm.weight", "decoder.layers.56.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.57.flat_param_0": {"names": ["decoder.layers.57.self_attn.qkv_proj.weight", "decoder.layers.57.self_attn.qkv_proj.bias", "decoder.layers.57.self_attn.out_proj.weight", "decoder.layers.57.self_attn.out_proj.bias", "decoder.layers.57.self_attn_layer_norm.weight", "decoder.layers.57.self_attn_layer_norm.bias", "decoder.layers.57.fc1.weight", "decoder.layers.57.fc1.bias", "decoder.layers.57.fc2.weight", "decoder.layers.57.fc2.bias", "decoder.layers.57.final_layer_norm.weight", "decoder.layers.57.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.58.flat_param_0": {"names": ["decoder.layers.58.self_attn.qkv_proj.weight", "decoder.layers.58.self_attn.qkv_proj.bias", "decoder.layers.58.self_attn.out_proj.weight", "decoder.layers.58.self_attn.out_proj.bias", "decoder.layers.58.self_attn_layer_norm.weight", "decoder.layers.58.self_attn_layer_norm.bias", "decoder.layers.58.fc1.weight", "decoder.layers.58.fc1.bias", "decoder.layers.58.fc2.weight", "decoder.layers.58.fc2.bias", "decoder.layers.58.final_layer_norm.weight", "decoder.layers.58.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.59.flat_param_0": {"names": ["decoder.layers.59.self_attn.qkv_proj.weight", "decoder.layers.59.self_attn.qkv_proj.bias", "decoder.layers.59.self_attn.out_proj.weight", "decoder.layers.59.self_attn.out_proj.bias", "decoder.layers.59.self_attn_layer_norm.weight", "decoder.layers.59.self_attn_layer_norm.bias", "decoder.layers.59.fc1.weight", "decoder.layers.59.fc1.bias", "decoder.layers.59.fc2.weight", "decoder.layers.59.fc2.bias", "decoder.layers.59.final_layer_norm.weight", "decoder.layers.59.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.60.flat_param_0": {"names": ["decoder.layers.60.self_attn.qkv_proj.weight", "decoder.layers.60.self_attn.qkv_proj.bias", "decoder.layers.60.self_attn.out_proj.weight", "decoder.layers.60.self_attn.out_proj.bias", "decoder.layers.60.self_attn_layer_norm.weight", "decoder.layers.60.self_attn_layer_norm.bias", "decoder.layers.60.fc1.weight", "decoder.layers.60.fc1.bias", "decoder.layers.60.fc2.weight", "decoder.layers.60.fc2.bias", "decoder.layers.60.final_layer_norm.weight", "decoder.layers.60.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.61.flat_param_0": {"names": ["decoder.layers.61.self_attn.qkv_proj.weight", "decoder.layers.61.self_attn.qkv_proj.bias", "decoder.layers.61.self_attn.out_proj.weight", "decoder.layers.61.self_attn.out_proj.bias", "decoder.layers.61.self_attn_layer_norm.weight", "decoder.layers.61.self_attn_layer_norm.bias", "decoder.layers.61.fc1.weight", "decoder.layers.61.fc1.bias", "decoder.layers.61.fc2.weight", "decoder.layers.61.fc2.bias", "decoder.layers.61.final_layer_norm.weight", "decoder.layers.61.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.62.flat_param_0": {"names": ["decoder.layers.62.self_attn.qkv_proj.weight", "decoder.layers.62.self_attn.qkv_proj.bias", "decoder.layers.62.self_attn.out_proj.weight", "decoder.layers.62.self_attn.out_proj.bias", "decoder.layers.62.self_attn_layer_norm.weight", "decoder.layers.62.self_attn_layer_norm.bias", "decoder.layers.62.fc1.weight", "decoder.layers.62.fc1.bias", "decoder.layers.62.fc2.weight", "decoder.layers.62.fc2.bias", "decoder.layers.62.final_layer_norm.weight", "decoder.layers.62.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.63.flat_param_0": {"names": ["decoder.layers.63.self_attn.qkv_proj.weight", "decoder.layers.63.self_attn.qkv_proj.bias", "decoder.layers.63.self_attn.out_proj.weight", "decoder.layers.63.self_attn.out_proj.bias", "decoder.layers.63.self_attn_layer_norm.weight", "decoder.layers.63.self_attn_layer_norm.bias", "decoder.layers.63.fc1.weight", "decoder.layers.63.fc1.bias", "decoder.layers.63.fc2.weight", "decoder.layers.63.fc2.bias", "decoder.layers.63.final_layer_norm.weight", "decoder.layers.63.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.64.flat_param_0": {"names": ["decoder.layers.64.self_attn.qkv_proj.weight", "decoder.layers.64.self_attn.qkv_proj.bias", "decoder.layers.64.self_attn.out_proj.weight", "decoder.layers.64.self_attn.out_proj.bias", "decoder.layers.64.self_attn_layer_norm.weight", "decoder.layers.64.self_attn_layer_norm.bias", "decoder.layers.64.fc1.weight", "decoder.layers.64.fc1.bias", "decoder.layers.64.fc2.weight", "decoder.layers.64.fc2.bias", "decoder.layers.64.final_layer_norm.weight", "decoder.layers.64.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.65.flat_param_0": {"names": ["decoder.layers.65.self_attn.qkv_proj.weight", "decoder.layers.65.self_attn.qkv_proj.bias", "decoder.layers.65.self_attn.out_proj.weight", "decoder.layers.65.self_attn.out_proj.bias", "decoder.layers.65.self_attn_layer_norm.weight", "decoder.layers.65.self_attn_layer_norm.bias", "decoder.layers.65.fc1.weight", "decoder.layers.65.fc1.bias", "decoder.layers.65.fc2.weight", "decoder.layers.65.fc2.bias", "decoder.layers.65.final_layer_norm.weight", "decoder.layers.65.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.66.flat_param_0": {"names": ["decoder.layers.66.self_attn.qkv_proj.weight", "decoder.layers.66.self_attn.qkv_proj.bias", "decoder.layers.66.self_attn.out_proj.weight", "decoder.layers.66.self_attn.out_proj.bias", "decoder.layers.66.self_attn_layer_norm.weight", "decoder.layers.66.self_attn_layer_norm.bias", "decoder.layers.66.fc1.weight", "decoder.layers.66.fc1.bias", "decoder.layers.66.fc2.weight", "decoder.layers.66.fc2.bias", "decoder.layers.66.final_layer_norm.weight", "decoder.layers.66.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.67.flat_param_0": {"names": ["decoder.layers.67.self_attn.qkv_proj.weight", "decoder.layers.67.self_attn.qkv_proj.bias", "decoder.layers.67.self_attn.out_proj.weight", "decoder.layers.67.self_attn.out_proj.bias", "decoder.layers.67.self_attn_layer_norm.weight", "decoder.layers.67.self_attn_layer_norm.bias", "decoder.layers.67.fc1.weight", "decoder.layers.67.fc1.bias", "decoder.layers.67.fc2.weight", "decoder.layers.67.fc2.bias", "decoder.layers.67.final_layer_norm.weight", "decoder.layers.67.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.68.flat_param_0": {"names": ["decoder.layers.68.self_attn.qkv_proj.weight", "decoder.layers.68.self_attn.qkv_proj.bias", "decoder.layers.68.self_attn.out_proj.weight", "decoder.layers.68.self_attn.out_proj.bias", "decoder.layers.68.self_attn_layer_norm.weight", "decoder.layers.68.self_attn_layer_norm.bias", "decoder.layers.68.fc1.weight", "decoder.layers.68.fc1.bias", "decoder.layers.68.fc2.weight", "decoder.layers.68.fc2.bias", "decoder.layers.68.final_layer_norm.weight", "decoder.layers.68.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.69.flat_param_0": {"names": ["decoder.layers.69.self_attn.qkv_proj.weight", "decoder.layers.69.self_attn.qkv_proj.bias", "decoder.layers.69.self_attn.out_proj.weight", "decoder.layers.69.self_attn.out_proj.bias", "decoder.layers.69.self_attn_layer_norm.weight", "decoder.layers.69.self_attn_layer_norm.bias", "decoder.layers.69.fc1.weight", "decoder.layers.69.fc1.bias", "decoder.layers.69.fc2.weight", "decoder.layers.69.fc2.bias", "decoder.layers.69.final_layer_norm.weight", "decoder.layers.69.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.70.flat_param_0": {"names": ["decoder.layers.70.self_attn.qkv_proj.weight", "decoder.layers.70.self_attn.qkv_proj.bias", "decoder.layers.70.self_attn.out_proj.weight", "decoder.layers.70.self_attn.out_proj.bias", "decoder.layers.70.self_attn_layer_norm.weight", "decoder.layers.70.self_attn_layer_norm.bias", "decoder.layers.70.fc1.weight", "decoder.layers.70.fc1.bias", "decoder.layers.70.fc2.weight", "decoder.layers.70.fc2.bias", "decoder.layers.70.final_layer_norm.weight", "decoder.layers.70.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.71.flat_param_0": {"names": ["decoder.layers.71.self_attn.qkv_proj.weight", "decoder.layers.71.self_attn.qkv_proj.bias", "decoder.layers.71.self_attn.out_proj.weight", "decoder.layers.71.self_attn.out_proj.bias", "decoder.layers.71.self_attn_layer_norm.weight", "decoder.layers.71.self_attn_layer_norm.bias", "decoder.layers.71.fc1.weight", "decoder.layers.71.fc1.bias", "decoder.layers.71.fc2.weight", "decoder.layers.71.fc2.bias", "decoder.layers.71.final_layer_norm.weight", "decoder.layers.71.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.72.flat_param_0": {"names": ["decoder.layers.72.self_attn.qkv_proj.weight", "decoder.layers.72.self_attn.qkv_proj.bias", "decoder.layers.72.self_attn.out_proj.weight", "decoder.layers.72.self_attn.out_proj.bias", "decoder.layers.72.self_attn_layer_norm.weight", "decoder.layers.72.self_attn_layer_norm.bias", "decoder.layers.72.fc1.weight", "decoder.layers.72.fc1.bias", "decoder.layers.72.fc2.weight", "decoder.layers.72.fc2.bias", "decoder.layers.72.final_layer_norm.weight", "decoder.layers.72.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.73.flat_param_0": {"names": ["decoder.layers.73.self_attn.qkv_proj.weight", "decoder.layers.73.self_attn.qkv_proj.bias", "decoder.layers.73.self_attn.out_proj.weight", "decoder.layers.73.self_attn.out_proj.bias", "decoder.layers.73.self_attn_layer_norm.weight", "decoder.layers.73.self_attn_layer_norm.bias", "decoder.layers.73.fc1.weight", "decoder.layers.73.fc1.bias", "decoder.layers.73.fc2.weight", "decoder.layers.73.fc2.bias", "decoder.layers.73.final_layer_norm.weight", "decoder.layers.73.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.74.flat_param_0": {"names": ["decoder.layers.74.self_attn.qkv_proj.weight", "decoder.layers.74.self_attn.qkv_proj.bias", "decoder.layers.74.self_attn.out_proj.weight", "decoder.layers.74.self_attn.out_proj.bias", "decoder.layers.74.self_attn_layer_norm.weight", "decoder.layers.74.self_attn_layer_norm.bias", "decoder.layers.74.fc1.weight", "decoder.layers.74.fc1.bias", "decoder.layers.74.fc2.weight", "decoder.layers.74.fc2.bias", "decoder.layers.74.final_layer_norm.weight", "decoder.layers.74.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.75.flat_param_0": {"names": ["decoder.layers.75.self_attn.qkv_proj.weight", "decoder.layers.75.self_attn.qkv_proj.bias", "decoder.layers.75.self_attn.out_proj.weight", "decoder.layers.75.self_attn.out_proj.bias", "decoder.layers.75.self_attn_layer_norm.weight", "decoder.layers.75.self_attn_layer_norm.bias", "decoder.layers.75.fc1.weight", "decoder.layers.75.fc1.bias", "decoder.layers.75.fc2.weight", "decoder.layers.75.fc2.bias", "decoder.layers.75.final_layer_norm.weight", "decoder.layers.75.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.76.flat_param_0": {"names": ["decoder.layers.76.self_attn.qkv_proj.weight", "decoder.layers.76.self_attn.qkv_proj.bias", "decoder.layers.76.self_attn.out_proj.weight", "decoder.layers.76.self_attn.out_proj.bias", "decoder.layers.76.self_attn_layer_norm.weight", "decoder.layers.76.self_attn_layer_norm.bias", "decoder.layers.76.fc1.weight", "decoder.layers.76.fc1.bias", "decoder.layers.76.fc2.weight", "decoder.layers.76.fc2.bias", "decoder.layers.76.final_layer_norm.weight", "decoder.layers.76.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.77.flat_param_0": {"names": ["decoder.layers.77.self_attn.qkv_proj.weight", "decoder.layers.77.self_attn.qkv_proj.bias", "decoder.layers.77.self_attn.out_proj.weight", "decoder.layers.77.self_attn.out_proj.bias", "decoder.layers.77.self_attn_layer_norm.weight", "decoder.layers.77.self_attn_layer_norm.bias", "decoder.layers.77.fc1.weight", "decoder.layers.77.fc1.bias", "decoder.layers.77.fc2.weight", "decoder.layers.77.fc2.bias", "decoder.layers.77.final_layer_norm.weight", "decoder.layers.77.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.78.flat_param_0": {"names": ["decoder.layers.78.self_attn.qkv_proj.weight", "decoder.layers.78.self_attn.qkv_proj.bias", "decoder.layers.78.self_attn.out_proj.weight", "decoder.layers.78.self_attn.out_proj.bias", "decoder.layers.78.self_attn_layer_norm.weight", "decoder.layers.78.self_attn_layer_norm.bias", "decoder.layers.78.fc1.weight", "decoder.layers.78.fc1.bias", "decoder.layers.78.fc2.weight", "decoder.layers.78.fc2.bias", "decoder.layers.78.final_layer_norm.weight", "decoder.layers.78.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.79.flat_param_0": {"names": ["decoder.layers.79.self_attn.qkv_proj.weight", "decoder.layers.79.self_attn.qkv_proj.bias", "decoder.layers.79.self_attn.out_proj.weight", "decoder.layers.79.self_attn.out_proj.bias", "decoder.layers.79.self_attn_layer_norm.weight", "decoder.layers.79.self_attn_layer_norm.bias", "decoder.layers.79.fc1.weight", "decoder.layers.79.fc1.bias", "decoder.layers.79.fc2.weight", "decoder.layers.79.fc2.bias", "decoder.layers.79.final_layer_norm.weight", "decoder.layers.79.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.80.flat_param_0": {"names": ["decoder.layers.80.self_attn.qkv_proj.weight", "decoder.layers.80.self_attn.qkv_proj.bias", "decoder.layers.80.self_attn.out_proj.weight", "decoder.layers.80.self_attn.out_proj.bias", "decoder.layers.80.self_attn_layer_norm.weight", "decoder.layers.80.self_attn_layer_norm.bias", "decoder.layers.80.fc1.weight", "decoder.layers.80.fc1.bias", "decoder.layers.80.fc2.weight", "decoder.layers.80.fc2.bias", "decoder.layers.80.final_layer_norm.weight", "decoder.layers.80.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.81.flat_param_0": {"names": ["decoder.layers.81.self_attn.qkv_proj.weight", "decoder.layers.81.self_attn.qkv_proj.bias", "decoder.layers.81.self_attn.out_proj.weight", "decoder.layers.81.self_attn.out_proj.bias", "decoder.layers.81.self_attn_layer_norm.weight", "decoder.layers.81.self_attn_layer_norm.bias", "decoder.layers.81.fc1.weight", "decoder.layers.81.fc1.bias", "decoder.layers.81.fc2.weight", "decoder.layers.81.fc2.bias", "decoder.layers.81.final_layer_norm.weight", "decoder.layers.81.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.82.flat_param_0": {"names": ["decoder.layers.82.self_attn.qkv_proj.weight", "decoder.layers.82.self_attn.qkv_proj.bias", "decoder.layers.82.self_attn.out_proj.weight", "decoder.layers.82.self_attn.out_proj.bias", "decoder.layers.82.self_attn_layer_norm.weight", "decoder.layers.82.self_attn_layer_norm.bias", "decoder.layers.82.fc1.weight", "decoder.layers.82.fc1.bias", "decoder.layers.82.fc2.weight", "decoder.layers.82.fc2.bias", "decoder.layers.82.final_layer_norm.weight", "decoder.layers.82.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.83.flat_param_0": {"names": ["decoder.layers.83.self_attn.qkv_proj.weight", "decoder.layers.83.self_attn.qkv_proj.bias", "decoder.layers.83.self_attn.out_proj.weight", "decoder.layers.83.self_attn.out_proj.bias", "decoder.layers.83.self_attn_layer_norm.weight", "decoder.layers.83.self_attn_layer_norm.bias", "decoder.layers.83.fc1.weight", "decoder.layers.83.fc1.bias", "decoder.layers.83.fc2.weight", "decoder.layers.83.fc2.bias", "decoder.layers.83.final_layer_norm.weight", "decoder.layers.83.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.84.flat_param_0": {"names": ["decoder.layers.84.self_attn.qkv_proj.weight", "decoder.layers.84.self_attn.qkv_proj.bias", "decoder.layers.84.self_attn.out_proj.weight", "decoder.layers.84.self_attn.out_proj.bias", "decoder.layers.84.self_attn_layer_norm.weight", "decoder.layers.84.self_attn_layer_norm.bias", "decoder.layers.84.fc1.weight", "decoder.layers.84.fc1.bias", "decoder.layers.84.fc2.weight", "decoder.layers.84.fc2.bias", "decoder.layers.84.final_layer_norm.weight", "decoder.layers.84.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.85.flat_param_0": {"names": ["decoder.layers.85.self_attn.qkv_proj.weight", "decoder.layers.85.self_attn.qkv_proj.bias", "decoder.layers.85.self_attn.out_proj.weight", "decoder.layers.85.self_attn.out_proj.bias", "decoder.layers.85.self_attn_layer_norm.weight", "decoder.layers.85.self_attn_layer_norm.bias", "decoder.layers.85.fc1.weight", "decoder.layers.85.fc1.bias", "decoder.layers.85.fc2.weight", "decoder.layers.85.fc2.bias", "decoder.layers.85.final_layer_norm.weight", "decoder.layers.85.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.86.flat_param_0": {"names": ["decoder.layers.86.self_attn.qkv_proj.weight", "decoder.layers.86.self_attn.qkv_proj.bias", "decoder.layers.86.self_attn.out_proj.weight", "decoder.layers.86.self_attn.out_proj.bias", "decoder.layers.86.self_attn_layer_norm.weight", "decoder.layers.86.self_attn_layer_norm.bias", "decoder.layers.86.fc1.weight", "decoder.layers.86.fc1.bias", "decoder.layers.86.fc2.weight", "decoder.layers.86.fc2.bias", "decoder.layers.86.final_layer_norm.weight", "decoder.layers.86.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.87.flat_param_0": {"names": ["decoder.layers.87.self_attn.qkv_proj.weight", "decoder.layers.87.self_attn.qkv_proj.bias", "decoder.layers.87.self_attn.out_proj.weight", "decoder.layers.87.self_attn.out_proj.bias", "decoder.layers.87.self_attn_layer_norm.weight", "decoder.layers.87.self_attn_layer_norm.bias", "decoder.layers.87.fc1.weight", "decoder.layers.87.fc1.bias", "decoder.layers.87.fc2.weight", "decoder.layers.87.fc2.bias", "decoder.layers.87.final_layer_norm.weight", "decoder.layers.87.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.88.flat_param_0": {"names": ["decoder.layers.88.self_attn.qkv_proj.weight", "decoder.layers.88.self_attn.qkv_proj.bias", "decoder.layers.88.self_attn.out_proj.weight", "decoder.layers.88.self_attn.out_proj.bias", "decoder.layers.88.self_attn_layer_norm.weight", "decoder.layers.88.self_attn_layer_norm.bias", "decoder.layers.88.fc1.weight", "decoder.layers.88.fc1.bias", "decoder.layers.88.fc2.weight", "decoder.layers.88.fc2.bias", "decoder.layers.88.final_layer_norm.weight", "decoder.layers.88.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.89.flat_param_0": {"names": ["decoder.layers.89.self_attn.qkv_proj.weight", "decoder.layers.89.self_attn.qkv_proj.bias", "decoder.layers.89.self_attn.out_proj.weight", "decoder.layers.89.self_attn.out_proj.bias", "decoder.layers.89.self_attn_layer_norm.weight", "decoder.layers.89.self_attn_layer_norm.bias", "decoder.layers.89.fc1.weight", "decoder.layers.89.fc1.bias", "decoder.layers.89.fc2.weight", "decoder.layers.89.fc2.bias", "decoder.layers.89.final_layer_norm.weight", "decoder.layers.89.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.90.flat_param_0": {"names": ["decoder.layers.90.self_attn.qkv_proj.weight", "decoder.layers.90.self_attn.qkv_proj.bias", "decoder.layers.90.self_attn.out_proj.weight", "decoder.layers.90.self_attn.out_proj.bias", "decoder.layers.90.self_attn_layer_norm.weight", "decoder.layers.90.self_attn_layer_norm.bias", "decoder.layers.90.fc1.weight", "decoder.layers.90.fc1.bias", "decoder.layers.90.fc2.weight", "decoder.layers.90.fc2.bias", "decoder.layers.90.final_layer_norm.weight", "decoder.layers.90.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.91.flat_param_0": {"names": ["decoder.layers.91.self_attn.qkv_proj.weight", "decoder.layers.91.self_attn.qkv_proj.bias", "decoder.layers.91.self_attn.out_proj.weight", "decoder.layers.91.self_attn.out_proj.bias", "decoder.layers.91.self_attn_layer_norm.weight", "decoder.layers.91.self_attn_layer_norm.bias", "decoder.layers.91.fc1.weight", "decoder.layers.91.fc1.bias", "decoder.layers.91.fc2.weight", "decoder.layers.91.fc2.bias", "decoder.layers.91.final_layer_norm.weight", "decoder.layers.91.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.92.flat_param_0": {"names": ["decoder.layers.92.self_attn.qkv_proj.weight", "decoder.layers.92.self_attn.qkv_proj.bias", "decoder.layers.92.self_attn.out_proj.weight", "decoder.layers.92.self_attn.out_proj.bias", "decoder.layers.92.self_attn_layer_norm.weight", "decoder.layers.92.self_attn_layer_norm.bias", "decoder.layers.92.fc1.weight", "decoder.layers.92.fc1.bias", "decoder.layers.92.fc2.weight", "decoder.layers.92.fc2.bias", "decoder.layers.92.final_layer_norm.weight", "decoder.layers.92.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.93.flat_param_0": {"names": ["decoder.layers.93.self_attn.qkv_proj.weight", "decoder.layers.93.self_attn.qkv_proj.bias", "decoder.layers.93.self_attn.out_proj.weight", "decoder.layers.93.self_attn.out_proj.bias", "decoder.layers.93.self_attn_layer_norm.weight", "decoder.layers.93.self_attn_layer_norm.bias", "decoder.layers.93.fc1.weight", "decoder.layers.93.fc1.bias", "decoder.layers.93.fc2.weight", "decoder.layers.93.fc2.bias", "decoder.layers.93.final_layer_norm.weight", "decoder.layers.93.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.94.flat_param_0": {"names": ["decoder.layers.94.self_attn.qkv_proj.weight", "decoder.layers.94.self_attn.qkv_proj.bias", "decoder.layers.94.self_attn.out_proj.weight", "decoder.layers.94.self_attn.out_proj.bias", "decoder.layers.94.self_attn_layer_norm.weight", "decoder.layers.94.self_attn_layer_norm.bias", "decoder.layers.94.fc1.weight", "decoder.layers.94.fc1.bias", "decoder.layers.94.fc2.weight", "decoder.layers.94.fc2.bias", "decoder.layers.94.final_layer_norm.weight", "decoder.layers.94.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.95.flat_param_0": {"names": ["decoder.layers.95.self_attn.qkv_proj.weight", "decoder.layers.95.self_attn.qkv_proj.bias", "decoder.layers.95.self_attn.out_proj.weight", "decoder.layers.95.self_attn.out_proj.bias", "decoder.layers.95.self_attn_layer_norm.weight", "decoder.layers.95.self_attn_layer_norm.bias", "decoder.layers.95.fc1.weight", "decoder.layers.95.fc1.bias", "decoder.layers.95.fc2.weight", "decoder.layers.95.fc2.bias", "decoder.layers.95.final_layer_norm.weight", "decoder.layers.95.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}} \ No newline at end of file diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/unflat.sh b/examples/tutorial/opt/inference/script/process-opt-175b/unflat.sh new file mode 100644 index 0000000000000000000000000000000000000000..cc5c190e24e82cc1e0b6d6177b0fe691a5942385 --- /dev/null +++ b/examples/tutorial/opt/inference/script/process-opt-175b/unflat.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env sh + +for i in $(seq 0 7); do + python convert_ckpt.py $1 $2 ${i} & +done + +wait $(jobs -p) diff --git a/examples/tutorial/opt/inference/script/processing_ckpt_66b.py b/examples/tutorial/opt/inference/script/processing_ckpt_66b.py new file mode 100644 index 0000000000000000000000000000000000000000..0494647d7bcce31a6ba1e3decf0943043730c2fa --- /dev/null +++ b/examples/tutorial/opt/inference/script/processing_ckpt_66b.py @@ -0,0 +1,55 @@ +import os +import torch +from multiprocessing import Pool + +# download pytorch model ckpt in https://huggingface.co/facebook/opt-66b/tree/main +# you can use whether wget or git lfs + +path = "/path/to/your/ckpt" +new_path = "/path/to/the/processed/ckpt/" + +assert os.path.isdir(path) +files = [] +for filename in os.listdir(path): + filepath = os.path.join(path, filename) + if os.path.isfile(filepath): + files.append(filepath) + +with Pool(14) as pool: + ckpts = pool.map(torch.load, files) + +restored = {} +for ckpt in ckpts: + for k,v in ckpt.items(): + if(k[0] == 'm'): + k = k[6:] + if(k == "lm_head.weight"): + k = "head.dense.weight" + if(k == "decoder.final_layer_norm.weight"): + k = "decoder.layer_norm.weight" + if(k == "decoder.final_layer_norm.bias"): + k = "decoder.layer_norm.bias" + restored[k] = v +restored["decoder.version"] = "0.0" + + +split_num = len(restored.keys()) // 60 +count = 0 +file_count = 1 +tmp = {} +for k,v in restored.items(): + print(k) + tmp[k] = v + count = count + 1 + if(count == split_num): + filename = str(file_count) + "-restored.pt" + torch.save(tmp, os.path.join(new_path, filename)) + file_count = file_count + 1 + count = 0 + tmp = {} + +filename = str(file_count) + "-restored.pt" +torch.save(tmp, os.path.join(new_path, filename)) + + + diff --git a/examples/tutorial/opt/opt/README.md b/examples/tutorial/opt/opt/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a01209cbda0e58f931289e794ac0f79b9e7f4b47 --- /dev/null +++ b/examples/tutorial/opt/opt/README.md @@ -0,0 +1,76 @@ + +# Train OPT model with Colossal-AI + + +## OPT +Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments. + +The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. + +We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before +the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). + +## Our Modifications +We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP. + +## ๐Ÿš€Quick Start for Tutorial +1. Install the dependency +```bash +pip install datasets accelerate +``` +2. Run finetuning with synthetic datasets with one GPU +```bash +bash ./run_clm_synthetic.sh +``` +3. Run finetuning with 4 GPUs +```bash +bash ./run_clm_synthetic.sh 16 0 125m 4 +``` + +## Quick Start for Practical Use +You can launch training by using the following bash script + +```bash +bash ./run_clm.sh +``` + +- batch-size-per-gpu: number of samples fed to each GPU, default is 16 +- mem-cap: limit memory usage within a value in GB, default is 0 (no limit) +- model: the size of the OPT model, default is `6.7b`. Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7`, `13b`, `30b`, `66b`. For `175b`, you can request +the pretrained weights from [OPT weight downloading page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT). +- gpu-num: the number of GPUs to use, default is 1. + +It uses `wikitext` dataset. + +To use synthetic dataset: + +```bash +bash ./run_clm_synthetic.sh +``` + +## Remarkable Performance +On a single GPU, Colossal-AIโ€™s automatic strategy provides remarkable performance gains from the ZeRO Offloading strategy by Microsoft DeepSpeed. +Users can experience up to a 40% speedup, at a variety of model scales. However, when using a traditional deep learning training framework like PyTorch, a single GPU can no longer support the training of models at such a scale. + +

+ +

+ +Adopting the distributed training strategy with 8 GPUs is as simple as adding a `-nprocs 8` to the training command of Colossal-AI! + +More details about behind the scenes can be found on the corresponding [blog](https://medium.com/@yangyou_berkeley/colossal-ai-seamlessly-accelerates-large-models-at-low-costs-with-hugging-face-4d1a887e500d), +and a detailed tutorial will be added in [Documentation](https://www.colossalai.org/docs/get_started/installation) very soon. diff --git a/examples/tutorial/opt/opt/benchmark.sh b/examples/tutorial/opt/opt/benchmark.sh new file mode 100644 index 0000000000000000000000000000000000000000..f02f7629ad16f42fc04f131c26869477fc59aa90 --- /dev/null +++ b/examples/tutorial/opt/opt/benchmark.sh @@ -0,0 +1,21 @@ +export BS=16 +export MEMCAP=0 +export MODEL="6.7b" +export GPUNUM=1 + +for MODEL in "6.7b" "13b" "1.3b" +do +for GPUNUM in 8 1 +do +for BS in 16 24 32 8 +do +for MEMCAP in 0 40 +do +pkill -9 torchrun +pkill -9 python + +bash ./run_clm.sh $BS $MEMCAP $MODEL $GPUNUM +done +done +done +done diff --git a/examples/tutorial/opt/opt/colossalai_zero.py b/examples/tutorial/opt/opt/colossalai_zero.py new file mode 100644 index 0000000000000000000000000000000000000000..833745f3e8d84ef76305ebbecc09822746243be3 --- /dev/null +++ b/examples/tutorial/opt/opt/colossalai_zero.py @@ -0,0 +1,6 @@ +from colossalai.zero.shard_utils import TensorShardStrategy + +zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(), + tensor_placement_policy="auto", + reuse_fp16_shard=True), + optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384)) diff --git a/examples/tutorial/opt/opt/context.py b/examples/tutorial/opt/opt/context.py new file mode 100644 index 0000000000000000000000000000000000000000..95f0abf1d8c92ed5766e5f0fa2c70618be7827c5 --- /dev/null +++ b/examples/tutorial/opt/opt/context.py @@ -0,0 +1,32 @@ +import torch.distributed as dist + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + + +class barrier_context(): + """ + This context manager is used to allow one process to execute while blocking all + other processes in the same process group. This is often useful when downloading is required + as we only want to download in one process to prevent file corruption. + Args: + executor_rank (int): the process rank to execute without blocking, all other processes will be blocked + parallel_mode (ParallelMode): the parallel mode corresponding to a process group + Usage: + with barrier_context(): + dataset = CIFAR10(root='./data', download=True) + """ + + def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL): + # the class name is lowercase by convention + current_rank = gpc.get_local_rank(parallel_mode=parallel_mode) + self.should_block = current_rank != executor_rank + self.group = gpc.get_group(parallel_mode=parallel_mode) + + def __enter__(self): + if self.should_block: + dist.barrier(group=self.group) + + def __exit__(self, exc_type, exc_value, exc_traceback): + if not self.should_block: + dist.barrier(group=self.group) diff --git a/examples/tutorial/opt/opt/requirements.txt b/examples/tutorial/opt/opt/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c34df7992d3f1b6fa19cd58ac36c6a9b3499c75f --- /dev/null +++ b/examples/tutorial/opt/opt/requirements.txt @@ -0,0 +1,6 @@ +colossalai +torch >= 1.8.1 +datasets >= 1.8.0 +sentencepiece != 0.1.92 +protobuf +accelerate == 0.13.2 diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py new file mode 100644 index 0000000000000000000000000000000000000000..c4f576cb18aadc658d0b52e467eaf3c24e0b5fa9 --- /dev/null +++ b/examples/tutorial/opt/opt/run_clm.py @@ -0,0 +1,636 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) +on a text file or a dataset without using HuggingFace Trainer. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. + +import math +import os +import time +from itertools import chain + +import datasets +import torch +import torch.distributed as dist +from accelerate.utils import set_seed +from context import barrier_context +from datasets import load_dataset +from packaging import version +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + +import colossalai +import transformers +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer +from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ProcessGroup +from colossalai.utils import get_current_device, get_dataloader +from colossalai.utils.model.colo_init_context import ColoInitContext +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoTokenizer, + GPT2Tokenizer, + OPTForCausalLM, + SchedulerType, + default_data_collator, + get_scheduler, +) +from transformers.utils.versions import require_version + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def get_time_stamp(): + torch.cuda.synchronize() + return time.time() + + +def parse_args(): + parser = colossalai.get_default_parser() + parser.add_argument("-s", "--synthetic", action="store_true") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument("--train_file", + type=str, + default=None, + help="A csv or a json file containing the training data.") + parser.add_argument("--validation_file", + type=str, + default=None, + help="A csv or a json file containing the validation data.") + parser.add_argument( + "--validation_split_percentage", + default=5, + help="The percentage of the train set used as validation set in case there's no validation split", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=True, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the ๐Ÿค— Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument("--num_warmup_steps", + type=int, + default=0, + help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--block_size", + type=int, + default=None, + help=("Optional input sequence length after tokenization. The training dataset will be truncated in block of" + " this size for training. Default to the model max input length for single sentence inputs (take into" + " account special tokens)."), + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=None, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument("--overwrite_cache", + type=bool, + default=False, + help="Overwrite the cached training and evaluation sets") + parser.add_argument("--no_keep_linebreaks", + action="store_true", + help="Do not keep line breaks when using TXT files.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_model_id", + type=str, + help="The name of the repository to keep in sync with the local `output_dir`.") + parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--checkpointing_steps", + type=str, + default=None, + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--with_tracking", + action="store_true", + help="Whether to enable experiment trackers for logging.", + ) + parser.add_argument( + "--report_to", + type=str, + default="all", + help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed."), + ) + + parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap") + parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu") + args = parser.parse_args() + + # Sanity checks + if not args.synthetic: + if args.dataset_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + return args + + +def colo_memory_cap(size_in_GB): + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + cuda_capacity = colo_device_memory_capacity(get_current_device()) + if size_in_GB * (1024**3) < cuda_capacity: + colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) + print("Using {} GB of GPU memory".format(size_in_GB)) + + +class DummyDataloader: + + def __init__(self, length, batch_size, seq_len, vocab_size): + self.length = length + self.batch_size = batch_size + self.seq_len = seq_len + self.vocab_size = vocab_size + + def generate(self): + input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=get_current_device()) + attention_mask = torch.ones_like(input_ids) + return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids} + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.generate() + else: + raise StopIteration + + def __len__(self): + return self.length + + +def main(): + args = parse_args() + disable_existing_loggers() + colossalai.launch_from_torch(config=dict()) + logger = get_dist_logger() + is_main_process = dist.get_rank() == 0 + + if is_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + if args.mem_cap > 0: + colo_memory_cap(args.mem_cap) + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}") + + # Handle the repository creation + with barrier_context(): + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + logger.info("Start preparing dataset", ranks=[0]) + if not args.synthetic: + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[:{args.validation_split_percentage}%]", + ) + raw_datasets["train"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[{args.validation_split_percentage}%:]", + ) + else: + data_files = {} + dataset_args = {} + if args.train_file is not None: + data_files["train"] = args.train_file + if args.validation_file is not None: + data_files["validation"] = args.validation_file + extension = args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks + raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args) + # If no validation data is there, validation_split_percentage will be used to divide the dataset. + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{args.validation_split_percentage}%]", + **dataset_args, + ) + raw_datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{args.validation_split_percentage}%:]", + **dataset_args, + ) + logger.info("Dataset is prepared", ranks=[0]) + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + if args.config_name: + config = AutoConfig.from_pretrained(args.config_name) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained(args.model_name_or_path) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + logger.info("Model config has been created", ranks=[0]) + + if args.model_name_or_path == 'facebook/opt-13b': + tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path) + else: + print(f'load model from {args.model_name_or_path}') + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) + logger.info(f"{tokenizer.__class__.__name__} has been created", ranks=[0]) + + if args.init_in_cpu: + init_dev = torch.device('cpu') + else: + init_dev = get_current_device() + + # build model + if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b': + # currently, there has a bug in pretrained opt-13b + # we can not import it until huggingface fix it + logger.info("Train a new model from scratch", ranks=[0]) + with ColoInitContext(device=init_dev): + model = OPTForCausalLM(config) + else: + logger.info("Finetune a pre-trained model", ranks=[0]) + with ColoInitContext(device=init_dev): + model = OPTForCausalLM.from_pretrained(args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + local_files_only=False) + + # enable graident checkpointing + model.gradient_checkpointing_enable() + + PLACEMENT_POLICY = 'auto' + cai_version = colossalai.__version__ + logger.info(f'using Colossal-AI version {cai_version}') + if version.parse(cai_version) > version.parse("0.1.10"): + from colossalai.nn.parallel import GeminiDDP + model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) + elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): + from colossalai.gemini import ChunkManager, GeminiManager + pg = ProcessGroup() + chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) + chunk_manager = ChunkManager(chunk_size, + pg, + enable_distributed_storage=True, + init_device=GeminiManager.get_default_device(PLACEMENT_POLICY)) + gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager) + model = ZeroDDP(model, gemini_manager) + + logger.info(f'{model.__class__.__name__} has been created', ranks=[0]) + + if not args.synthetic: + # Preprocessing the datasets. + # First we tokenize all the texts. + column_names = raw_datasets["train"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + def tokenize_function(examples): + return tokenizer(examples[text_column_name]) + + with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA): + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + + if args.block_size is None: + block_size = tokenizer.model_max_length + if block_size > 1024: + logger.warning( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + "Picking 1024 instead. You can change that default value by passing --block_size xxx.") + block_size = 1024 + else: + if args.block_size > tokenizer.model_max_length: + logger.warning(f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.") + block_size = min(args.block_size, tokenizer.model_max_length) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i:i + block_size] for i in range(0, total_length, block_size) + ] for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + if not args.synthetic: + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder + # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower + # to preprocess. + # + # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map + + with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA): + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=args.preprocessing_num_workers, + load_from_cache_file=not args.overwrite_cache, + desc=f"Grouping texts in chunks of {block_size}", + ) + + train_dataset = lm_datasets["train"] + eval_dataset = lm_datasets["validation"] + + # Log a few random samples from the training set: + # for index in random.sample(range(len(train_dataset)), 3): + # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # DataLoaders creation: + train_dataloader = get_dataloader(train_dataset, + shuffle=True, + add_sampler=True, + collate_fn=default_data_collator, + batch_size=args.per_device_train_batch_size) + eval_dataloader = DataLoader(eval_dataset, + collate_fn=default_data_collator, + batch_size=args.per_device_eval_batch_size) + else: + train_dataloader = DummyDataloader(30, args.per_device_train_batch_size, config.max_position_embeddings, + config.vocab_size) + eval_dataloader = DummyDataloader(10, args.per_device_train_batch_size, config.max_position_embeddings, + config.vocab_size) + logger.info("Dataloaders have been created", ranks=[0]) + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate) + optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Train! + total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA) + num_train_samples = len(train_dataset) if not args.synthetic else 30 * total_batch_size + num_eval_samples = len(eval_dataset) if not args.synthetic else 10 * total_batch_size + + logger.info("***** Running training *****", ranks=[0]) + logger.info(f" Num examples = {num_train_samples}", ranks=[0]) + logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0]) + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0]) + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0]) + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0]) + logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0]) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process) + completed_steps = 0 + starting_epoch = 0 + global_step = 0 + + for epoch in range(starting_epoch, args.num_train_epochs): + + if completed_steps >= args.max_train_steps: + break + + model.train() + for step, batch in enumerate(train_dataloader): + batch = {k: v.cuda() for k, v in batch.items()} + outputs = model(use_cache=False, **batch) + loss = outputs['loss'] + optimizer.backward(loss) + + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + completed_steps += 1 + + global_step += 1 + logger.info("Global step {} finished".format(global_step + 1), ranks=[0]) + + if completed_steps >= args.max_train_steps: + break + + model.eval() + losses = [] + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + batch = {k: v.cuda() for k, v in batch.items()} + outputs = model(**batch) + + loss = outputs['loss'].unsqueeze(0) + losses.append(loss) + + losses = torch.cat(losses) + losses = losses[:num_eval_samples] + try: + eval_loss = torch.mean(losses) + perplexity = math.exp(eval_loss) + except OverflowError: + perplexity = float("inf") + + logger.info(f"Epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0]) + + if args.output_dir is not None: + model_state = model.state_dict() + if is_main_process: + torch.save(model_state, args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) + dist.barrier() + # load_state = torch.load(args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) + # model.load_state_dict(load_state, strict=False) + + logger.info("Training finished", ranks=[0]) + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/opt/opt/run_clm.sh b/examples/tutorial/opt/opt/run_clm.sh new file mode 100644 index 0000000000000000000000000000000000000000..858d3325a7b4f9f55592ac4a7d62836f2a0a0501 --- /dev/null +++ b/examples/tutorial/opt/opt/run_clm.sh @@ -0,0 +1,22 @@ +set -x +export BS=${1:-16} +export MEMCAP=${2:-0} +export MODEL=${3:-"125m"} +export GPUNUM=${4:-1} + +# make directory for logs +mkdir -p ./logs + +export MODLE_PATH="facebook/opt-${MODEL}" + +# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 +torchrun \ + --nproc_per_node ${GPUNUM} \ + --master_port 19198 \ + run_clm.py \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --output_dir $PWD \ + --mem_cap ${MEMCAP} \ + --model_name_or_path ${MODLE_PATH} \ + --per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log diff --git a/examples/tutorial/opt/opt/run_clm_synthetic.sh b/examples/tutorial/opt/opt/run_clm_synthetic.sh new file mode 100644 index 0000000000000000000000000000000000000000..80435f16ce2d24b634bdb16017092c5a32ea599b --- /dev/null +++ b/examples/tutorial/opt/opt/run_clm_synthetic.sh @@ -0,0 +1,21 @@ +set -x +export BS=${1:-16} +export MEMCAP=${2:-0} +export MODEL=${3:-"125m"} +export GPUNUM=${4:-1} + +# make directory for logs +mkdir -p ./logs + +export MODLE_PATH="facebook/opt-${MODEL}" + +# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 +torchrun \ + --nproc_per_node ${GPUNUM} \ + --master_port 19198 \ + run_clm.py \ + -s \ + --output_dir $PWD \ + --mem_cap ${MEMCAP} \ + --model_name_or_path ${MODLE_PATH} \ + --per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log diff --git a/examples/tutorial/sequence_parallel/README.md b/examples/tutorial/sequence_parallel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7058f53db8b67590dc1c2948c8c66a6358d811e0 --- /dev/null +++ b/examples/tutorial/sequence_parallel/README.md @@ -0,0 +1,151 @@ +# Sequence Parallelism with BERT + +In this example, we implemented BERT with sequence parallelism. Sequence parallelism splits the input tensor and intermediate +activation along the sequence dimension. This method can achieve better memory efficiency and allows us to train with larger batch size and longer sequence length. + +Paper: [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120) + +## ๐Ÿš€Quick Start +1. Run with the following command +```bash +export PYTHONPATH=$PWD +colossalai run --nproc_per_node 4 train.py -s +``` +2. The default config is sequence parallel size = 2, pipeline size = 1, letโ€™s change pipeline size to be 2 and try it again. + + +## How to Prepare WikiPedia Dataset + +First, let's prepare the WikiPedia dataset from scratch. To generate a preprocessed dataset, we need four items: +1. raw WikiPedia dataset +2. wikipedia extractor (extract data from the raw dataset) +3. vocabulary file +4. preprocessing scripts (generate final data from extracted data) + +For the preprocessing script, we thank Megatron-LM for providing a preprocessing script to generate the corpus file. + +```python +# download raw data +mkdir data && cd ./data +wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2 + +# install wiki extractor +git clone https://github.com/FrankLeeeee/wikiextractor.git +pip install ./wikiextractor + +# extractmodule +wikiextractor --json enwiki-latest-pages-articles.xml.bz2 +cat text/*/* > ./corpus.json +cd .. + +# download vocab file +mkdir vocab && cd ./vocab +wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt +cd .. + +# preprocess some data +git clone https://github.com/NVIDIA/Megatron-LM.git +cd ./Megatron-LM +python tools/preprocess_data.py \ + --input ../data/corpus.json \ + --output-prefix my-bert \ + --vocab ../vocab/bert-large-uncased-vocab.txt \ + --dataset-impl mmap \ + --tokenizer-type BertWordPieceLowerCase \ + --split-sentences \ + --workers 24 +``` + +After running the preprocessing scripts, you will obtain two files: +1. my-bert_text_sentence.bin +2. my-bert_text_sentence.idx + +If you happen to encouter `index out of range` problem when running Megatron's script, +this is probably because that a sentence starts with a punctuation and cannot be tokenized. A work-around is to update `Encoder.encode` method with the code below: + +```python +class Encoder(object): + def __init__(self, args): + ... + + def initializer(self): + ... + + def encode(self, json_line): + data = json.loads(json_line) + ids = {} + for key in self.args.json_keys: + text = data[key] + doc_ids = [] + + # lsg: avoid sentences which start with a punctuation + # as it cannot be tokenized by splitter + if len(text) > 0 and text[0] in string.punctuation: + text = text[1:] + + for sentence in Encoder.splitter.tokenize(text): + sentence_ids = Encoder.tokenizer.tokenize(sentence) + if len(sentence_ids) > 0: + doc_ids.append(sentence_ids) + if len(doc_ids) > 0 and self.args.append_eod: + doc_ids[-1].append(Encoder.tokenizer.eod) + ids[key] = doc_ids + return ids, len(json_line) +``` + +## How to Train with Sequence Parallelism + +We provided `train.py` for you to execute training. Before invoking the script, there are several +steps to perform. + +### Step 1. Set data path and vocab path + +At the top of `config.py`, you can see two global variables `DATA_PATH` and `VOCAB_FILE_PATH`. + +```python +DATA_PATH = +VOCAB_FILE_PATH = +``` + +`DATA_PATH` refers to the path to the data file generated by Megatron's script. For example, in the section above, you should get two data files (my-bert_text_sentence.bin and my-bert_text_sentence.idx). You just need to `DATA_PATH` to the path to the bin file without the file extension. + +For example, if your my-bert_text_sentence.bin is /home/Megatron-LM/my-bert_text_sentence.bin, then you should set + +```python +DATA_PATH = '/home/Megatron-LM/my-bert_text_sentence' +``` + +The `VOCAB_FILE_PATH` refers to the path to the vocabulary downloaded when you prepare the dataset +(e.g. bert-large-uncased-vocab.txt). + +### Step 3. Make Dataset Helper + +Build BERT dataset helper. Requirements are `CUDA`, `g++`, `pybind11` and `make`. + +```python +cd ./data/datasets +make +``` + +### Step 3. Configure your parameters + +In the `config.py` provided, a set of parameters are defined including training scheme, model, etc. +You can also modify the ColossalAI setting. For example, if you wish to parallelize over the +sequence dimension on 8 GPUs. You can change `size=4` to `size=8`. If you wish to use pipeline parallelism, you can set `pipeline=`. + +### Step 4. Invoke parallel training + +Lastly, you can start training with sequence parallelism. How you invoke `train.py` depends on your +machine setting. + +- If you are using a single machine with multiple GPUs, PyTorch launch utility can easily let you + start your script. A sample command is like below: + + ```bash + colossalai run --nproc_per_node --master_addr localhost --master_port 29500 train.py + ``` + +- If you are using multiple machines with multiple GPUs, we suggest that you refer to `colossalai + launch_from_slurm` or `colossalai.launch_from_openmpi` as it is easier to use SLURM and OpenMPI + to start multiple processes over multiple nodes. If you have your own launcher, you can fall back + to the default `colossalai.launch` function. diff --git a/examples/tutorial/sequence_parallel/config.py b/examples/tutorial/sequence_parallel/config.py new file mode 100644 index 0000000000000000000000000000000000000000..df0c5282f03243d14698704c252a953816d83332 --- /dev/null +++ b/examples/tutorial/sequence_parallel/config.py @@ -0,0 +1,38 @@ +from colossalai.amp import AMP_TYPE + +DATA_PATH = '' +VOCAB_FILE_PATH = '' + +# hyper-parameters +TRAIN_ITERS = 1000000 +DECAY_ITERS = 990000 +WARMUP_FRACTION = 0.01 +GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU +EVAL_ITERS = 10 +EVAL_INTERVAL = 10 +LR = 0.0001 +MIN_LR = 1e-05 +WEIGHT_DECAY = 0.01 +SEQ_LENGTH = 512 + +# BERT config +DEPTH = 12 +NUM_ATTENTION_HEADS = 12 +HIDDEN_SIZE = 768 + +# model config +ADD_BINARY_HEAD = False + +# random seed +SEED = 1234 + +# pipeline config +# only enabled when pipeline > 1 +NUM_MICRO_BATCHES = 4 + +# colossalai config +parallel = dict(pipeline=1, tensor=dict(size=2, mode='sequence')) + +fp16 = dict(mode=AMP_TYPE.NAIVE, verbose=True) + +gradient_handler = [dict(type='SequenceParallelGradientHandler')] diff --git a/examples/tutorial/sequence_parallel/data/__init__.py b/examples/tutorial/sequence_parallel/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef2d999389fe001b01342e66942c69455327efb --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/__init__.py @@ -0,0 +1,102 @@ +from colossalai.context.parallel_context import ParallelContext +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.context import ParallelMode +from .datasets.data_samplers import build_pretraining_data_loader +from .datasets.builder import build_train_valid_test_datasets +import torch + + +def cyclic_iter(iter): + while True: + for x in iter: + yield x + + +def build_train_valid_test_data_iterators(train_iters, + global_batch_size, + eval_interval, + eval_iters, + dataloader_type='single', + **kwargs + ): + (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) + + logger = get_dist_logger() + logger.info('> building train, validation, and test datasets ...', ranks=[0]) + + # Backward compatibility, assume fixed batch size. + # if iteration > 0 and consumed_train_samples == 0: + # assert train_samples is None, \ + # 'only backward compatibility support for iteration-based training' + # consumed_train_samples = iteration * global_batch_size + # if iteration > 0 and consumed_valid_samples == 0: + # if train_samples is None: + # consumed_valid_samples = (iteration // eval_interval) * \ + # eval_iters * global_batch_size + + # Data loader only on rank 0 of each model parallel group. + if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: + + # Number of train/valid/test samples. + train_samples = train_iters * global_batch_size + eval_iters_ = (train_iters // eval_interval + 1) * eval_iters + test_iters = eval_iters + train_val_test_num_samples = [train_samples, + eval_iters_ * global_batch_size, + test_iters * global_batch_size] + logger.info(' > datasets target sizes (minimum size):') + logger.info(' train: {}'.format(train_val_test_num_samples[0]), ranks=[0]) + logger.info(' validation: {}'.format(train_val_test_num_samples[1]), ranks=[0]) + logger.info(' test: {}'.format(train_val_test_num_samples[2]), ranks=[0]) + + # Build the datasets. + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + train_valid_test_num_samples=train_val_test_num_samples, **kwargs) + + # Build dataloaders. + dp_size = gpc.get_world_size(ParallelMode.DATA) + train_dataloader = build_pretraining_data_loader( + train_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size) + valid_dataloader = build_pretraining_data_loader( + valid_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size) + test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size//dp_size) + + # Flags to know if we need to do training/validation/testing. + do_train = train_dataloader is not None and train_iters > 0 + do_valid = valid_dataloader is not None and eval_iters > 0 + do_test = test_dataloader is not None and eval_iters > 0 + # Need to broadcast num_tokens and num_type_tokens. + flags = torch.cuda.LongTensor( + [int(do_train), int(do_valid), int(do_test)]) + else: + flags = torch.cuda.LongTensor([0, 0, 0]) + + # Broadcast num tokens. + torch.distributed.broadcast(flags, + gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], + group=gpc.get_group(ParallelMode.TENSOR)) + + # Build iterators. + dl_type = dataloader_type + assert dl_type in ['single', 'cyclic'] + + if train_dataloader is not None: + train_data_iterator = iter(train_dataloader) if dl_type == 'single' \ + else iter(cyclic_iter(train_dataloader)) + else: + train_data_iterator = None + + if valid_dataloader is not None: + valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \ + else iter(cyclic_iter(valid_dataloader)) + else: + valid_data_iterator = None + + if test_dataloader is not None: + test_data_iterator = iter(test_dataloader) if dl_type == 'single' \ + else iter(cyclic_iter(test_dataloader)) + else: + test_data_iterator = None + + return train_data_iterator, valid_data_iterator, test_data_iterator diff --git a/examples/tutorial/sequence_parallel/data/bert_helper.py b/examples/tutorial/sequence_parallel/data/bert_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..d092db3e7dd8d545253e3a36c6203ace3d0eec9d --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/bert_helper.py @@ -0,0 +1,165 @@ +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +import torch + +_MAX_DATA_DIM = 5 + + +def _build_key_size_numel_dictionaries(keys, data): + """Build the size on rank 0 and broadcast.""" + max_dim = _MAX_DATA_DIM + sizes = [0 for _ in range(max_dim) for _ in keys] + + # Pack the sizes on rank zero. + if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: + offset = 0 + for key in keys: + assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' + size = data[key].size() + for i, s in enumerate(size): + sizes[i + offset] = s + offset += max_dim + + # Move to GPU and broadcast. + sizes_cuda = torch.cuda.LongTensor(sizes) + torch.distributed.broadcast(sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], + group=gpc.get_group(ParallelMode.TENSOR)) + + # Move back to cpu and unpack. + sizes_cpu = sizes_cuda.cpu() + key_size = {} + key_numel = {} + total_numel = 0 + offset = 0 + for key in keys: + i = 0 + size = [] + numel = 1 + while sizes_cpu[offset + i] > 0: + this_size = sizes_cpu[offset + i] + size.append(this_size) + numel *= this_size + i += 1 + key_size[key] = size + key_numel[key] = numel + total_numel += numel + offset += max_dim + + return key_size, key_numel, total_numel + + +def broadcast_data(keys, data, datatype): + """Broadcast data from rank zero of each model parallel group to the + members of the same model parallel group. + + Arguments: + keys: list of keys in the data dictionary to be broadcasted + data: data dictionary of string keys and cpu tensor values. + datatype: torch data type of all tensors in data associated + with keys. + """ + # Build (key, size) and (key, number of elements) dictionaries along + # with the total number of elements on all ranks. + key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, + data) + + # Pack on rank zero. + if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # Check that all keys have the same data type. + # Flatten the data associated with the keys + flatten_data = torch.cat( + [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() + else: + flatten_data = torch.empty(total_numel, + device=torch.cuda.current_device(), + dtype=datatype) + + # Broadcast + torch.distributed.broadcast(flatten_data, + gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], + group=gpc.get_group(ParallelMode.TENSOR)) + + # Unpack + output = {} + offset = 0 + for key in keys: + size = key_size[key] + numel = key_numel[key] + output[key] = flatten_data.narrow(0, offset, numel).view(size) + offset += numel + + return output + + +def get_batch(data_iterator): + """Build the batch.""" + + # Items and their type. + keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = broadcast_data(keys, data, datatype) + + # Unpack. + tokens = data_b['text'].long() + types = data_b['types'].long() + sentence_order = data_b['is_random'].long() + loss_mask = data_b['loss_mask'].float() + lm_labels = data_b['labels'].long() + padding_mask = data_b['padding_mask'].long() + + return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask + + +def get_batch_for_sequence_parallel(data_iterator): + """Build the batch.""" + + # Items and their type. + keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + + # unpack + data_b = broadcast_data(keys, data, datatype) + + # # get tensor parallel local rank + global_rank = torch.distributed.get_rank() + local_world_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR) + local_rank = global_rank % local_world_size + seq_length = data_b['text'].size(1) + sub_seq_length = seq_length // local_world_size + sub_seq_start = local_rank * sub_seq_length + sub_seq_end = (local_rank+1) * sub_seq_length + # + # # Unpack. + tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long() + types = data_b['types'][:, sub_seq_start:sub_seq_end].long() + sentence_order = data_b['is_random'].long() + loss_mask = data_b['loss_mask'][:, sub_seq_start:sub_seq_end].float() + lm_labels = data_b['labels'][:, sub_seq_start:sub_seq_end].long() + padding_mask = data_b['padding_mask'].long() + + return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask + + +class SequenceParallelDataIterator: + + def __init__(self, data_iter): + self.data_iter = data_iter + + + def __iter__(self): + return self.data_iter + + def __next__(self): + return get_batch_for_sequence_parallel(self.data_iter) \ No newline at end of file diff --git a/examples/tutorial/sequence_parallel/data/datasets/Makefile b/examples/tutorial/sequence_parallel/data/datasets/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..8f9db7686696fbea6c94b998db4b40ef426c748d --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/Makefile @@ -0,0 +1,9 @@ +CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color +CPPFLAGS += $(shell python3 -m pybind11 --includes) +LIBNAME = helpers +LIBEXT = $(shell python3-config --extension-suffix) + +default: $(LIBNAME)$(LIBEXT) + +%$(LIBEXT): %.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ diff --git a/examples/tutorial/sequence_parallel/data/datasets/__init__.py b/examples/tutorial/sequence_parallel/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd5f898c6bdf89c6cf0243af102d04f6efed86b8 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/__init__.py @@ -0,0 +1 @@ +from . import indexed_dataset diff --git a/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d6388bd9f8e427575f345c20f38aa276a97be049 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py @@ -0,0 +1,236 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BERT Style dataset.""" + +import os +import time + +import numpy as np +import torch +from torch.utils.data import Dataset + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger + +from ..tokenizer import get_tokenizer +from .dataset_utils import ( + create_masked_lm_predictions, + create_tokens_and_tokentypes, + get_a_and_b_segments, + pad_and_convert_to_numpy, + truncate_segments, +) + +try: + from . import helpers +except: + print("helper is not built, ignore this message if you are using synthetic data.") + + +class BertDataset(Dataset): + + def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_samples, masked_lm_prob, max_seq_length, + short_seq_prob, seed, binary_head): + + # Params to store. + self.name = name + self.seed = seed + self.masked_lm_prob = masked_lm_prob + self.max_seq_length = max_seq_length + self.binary_head = binary_head + + # Dataset. + self.indexed_dataset = indexed_dataset + + # Build the samples mapping. + self.samples_mapping = get_samples_mapping_( + self.indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + self.max_seq_length - 3, # account for added tokens, + short_seq_prob, + self.seed, + self.name, + self.binary_head) + + # Vocab stuff. + tokenizer = get_tokenizer() + self.vocab_id_list = list(tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_dict = tokenizer.inv_vocab + self.cls_id = tokenizer.cls + self.sep_id = tokenizer.sep + self.mask_id = tokenizer.mask + self.pad_id = tokenizer.pad + + def __len__(self): + return self.samples_mapping.shape[0] + + def __getitem__(self, idx): + start_idx, end_idx, seq_length = self.samples_mapping[idx] + sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)] + # Note that this rng state should be numpy and not python since + # python randint is inclusive whereas the numpy one is exclusive. + # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1 + np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32)) + return build_training_sample( + sample, + seq_length, + self.max_seq_length, # needed for padding + self.vocab_id_list, + self.vocab_id_to_token_dict, + self.cls_id, + self.sep_id, + self.mask_id, + self.pad_id, + self.masked_lm_prob, + np_rng, + self.binary_head) + + +def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, + seed, name, binary_head): + logger = get_dist_logger() + if not num_epochs: + if not max_num_samples: + raise ValueError("Need to specify either max_num_samples " + "or num_epochs") + num_epochs = np.iinfo(np.int32).max - 1 + if not max_num_samples: + max_num_samples = np.iinfo(np.int64).max - 1 + + # Filename of the index mapping + indexmap_filename = data_prefix + indexmap_filename += '_{}_indexmap'.format(name) + if num_epochs != (np.iinfo(np.int32).max - 1): + indexmap_filename += '_{}ep'.format(num_epochs) + if max_num_samples != (np.iinfo(np.int64).max - 1): + indexmap_filename += '_{}mns'.format(max_num_samples) + indexmap_filename += '_{}msl'.format(max_seq_length) + indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob) + indexmap_filename += '_{}s'.format(seed) + indexmap_filename += '.npy' + + # Build the indexed mapping if not exist. + if torch.distributed.get_rank() == 0 and \ + not os.path.isfile(indexmap_filename): + print(' > WARNING: could not find index map file {}, building ' + 'the indices on rank 0 ...'.format(indexmap_filename)) + + # Make sure the types match the helpers input types. + assert indexed_dataset.doc_idx.dtype == np.int64 + assert indexed_dataset.sizes.dtype == np.int32 + + # Build samples mapping + verbose = torch.distributed.get_rank() == 0 + start_time = time.time() + logger.info('\n > building samples index mapping for {} ...'.format(name), ranks=[0]) + # First compile and then import. + samples_mapping = helpers.build_mapping(indexed_dataset.doc_idx, indexed_dataset.sizes, num_epochs, + max_num_samples, max_seq_length, short_seq_prob, seed, verbose, + 2 if binary_head else 1) + logger.info('\n > done building samples index maping', ranks=[0]) + np.save(indexmap_filename, samples_mapping, allow_pickle=True) + logger.info('\n > saved the index mapping in {}'.format(indexmap_filename), ranks=[0]) + # Make sure all the ranks have built the mapping + logger.info('\n > elapsed time to build and save samples mapping ' + '(seconds): {:4f}'.format(time.time() - start_time), + ranks=[0]) + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.DATA)) + if gpc.is_initialized(ParallelMode.PIPELINE): + torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.PIPELINE)) + assert counts[0].item() == (torch.distributed.get_world_size() // + torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE))) + + # Load indexed dataset. + start_time = time.time() + samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') + logger.info('\n > loading indexed mapping from {}'.format(indexmap_filename) + + '\n loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time) + + '\n total number of samples: {}'.format(samples_mapping.shape[0]), + ranks=[0]) + + return samples_mapping + + +def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_list, vocab_id_to_token_dict, cls_id, + sep_id, mask_id, pad_id, masked_lm_prob, np_rng, binary_head): + """Build training sample. + + Arguments: + sample: A list of sentences in which each sentence is a list token ids. + target_seq_length: Desired sequence length. + max_seq_length: Maximum length of the sequence. All values are padded to + this length. + vocab_id_list: List of vocabulary ids. Used to pick a random id. + vocab_id_to_token_dict: A dictionary from vocab ids to text tokens. + cls_id: Start of example id. + sep_id: Separator id. + mask_id: Mask token id. + pad_id: Padding token id. + masked_lm_prob: Probability to mask tokens. + np_rng: Random number genenrator. Note that this rng state should be + numpy and not python since python randint is inclusive for + the opper bound whereas the numpy one is exclusive. + """ + + if binary_head: + # We assume that we have at least two sentences in the sample + assert len(sample) > 1 + assert target_seq_length <= max_seq_length + + # Divide sample into two segments (A and B). + if binary_head: + tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng) + else: + tokens_a = [] + for j in range(len(sample)): + tokens_a.extend(sample[j]) + tokens_b = [] + is_next_random = False + + # Truncate to `target_sequence_length`. + max_num_tokens = target_seq_length + truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b), max_num_tokens, np_rng) + + # Build tokens and toketypes. + tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id) + + # Masking. + max_predictions_per_seq = masked_lm_prob * max_num_tokens + (tokens, masked_positions, masked_labels, + _) = create_masked_lm_predictions(tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, cls_id, sep_id, + mask_id, max_predictions_per_seq, np_rng) + + # Padding. + tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \ + = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, + masked_labels, pad_id, max_seq_length) + + train_sample = { + 'text': tokens_np, + 'types': tokentypes_np, + 'labels': labels_np, + 'is_random': int(is_next_random), + 'loss_mask': loss_mask_np, + 'padding_mask': padding_mask_np, + 'truncated': int(truncated) + } + return train_sample diff --git a/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6a06c869d8c808af92ed3a6993a57cca9ca78a8b --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Blendable dataset.""" + +import time + +import numpy as np +import torch + + +class BlendableDataset(torch.utils.data.Dataset): + + def __init__(self, datasets, weights): + + self.datasets = datasets + num_datasets = len(datasets) + assert num_datasets == len(weights) + + self.size = 0 + for dataset in self.datasets: + self.size += len(dataset) + + # Normalize weights. + weights = np.array(weights, dtype=np.float64) + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + weights /= sum_weights + + # Build indices. + start_time = time.time() + assert num_datasets < 255 + self.dataset_index = np.zeros(self.size, dtype=np.uint8) + self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) + + from . import helpers + helpers.build_blending_indices(self.dataset_index, + self.dataset_sample_index, + weights, num_datasets, self.size, + torch.distributed.get_rank() == 0) + print('> elapsed time for building blendable dataset indices: ' + '{:.2f} (sec)'.format(time.time() - start_time)) + + def __len__(self): + return self.size + + def __getitem__(self, idx): + dataset_idx = self.dataset_index[idx] + sample_idx = self.dataset_sample_index[idx] + return self.datasets[dataset_idx][sample_idx] diff --git a/examples/tutorial/sequence_parallel/data/datasets/builder.py b/examples/tutorial/sequence_parallel/data/datasets/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..6106f833b4628a0763cae82d0a7b6073f5c0548d --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/builder.py @@ -0,0 +1,152 @@ +from .blendable_dataset import BlendableDataset +from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_ +from .bert_dataset import BertDataset +from colossalai.logging import get_dist_logger + +DSET_TYPE_BERT = 'standard_bert' +DSET_TYPE_ICT = 'ict' +DSET_TYPE_T5 = 't5' + +DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5] + + +def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, skip_warmup, + binary_head, + dataset_type='standard_bert'): + + if dataset_type not in DSET_TYPES: + raise ValueError("Invalid dataset_type: ", dataset_type) + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + skip_warmup) + + # Get start and end indices of train/valid/train into doc-idx + # Note that doc-idx is designed to be num-docs + 1 so we can + # easily iterate over it. + total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1 + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + logger = get_dist_logger() + + # Print stats about the splits. + logger.info('\n > dataset split:', ranks=[0]) + + def print_split_stats(name, index): + start_index = indexed_dataset.doc_idx[splits[index]] + end_index = indexed_dataset.doc_idx[splits[index + 1]] + logger.info('\n {}:'.format(name) + + '\n document indices in [{}, {}) total of {} documents'.format( + splits[index], splits[index + 1], + splits[index + 1] - splits[index]) + + '\n sentence indices in [{}, {}) total of {} sentences'.format( + start_index, end_index, + end_index - start_index), + ranks=[0]) + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + # Get the pointer to the original doc-idx so we can set it later. + doc_idx_ptr = indexed_dataset.get_doc_idx() + # Slice the doc-idx + start_index = splits[index] + # Add +1 so we can index into the dataset to get the upper bound. + end_index = splits[index + 1] + 1 + # New doc_idx view. + indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) + # Build the dataset accordingly. + kwargs = dict( + name=name, + data_prefix=data_prefix, + num_epochs=None, + max_num_samples=train_valid_test_num_samples[index], + max_seq_length=max_seq_length, + seed=seed, + ) + + if dataset_type != DSET_TYPE_BERT: + raise NotImplementedError("Only BERT dataset is supported") + else: + dataset = BertDataset( + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + short_seq_prob=short_seq_prob, + binary_head=binary_head, + **kwargs + ) + + # Set the original pointer so dataset remains the main dataset. + indexed_dataset.set_doc_idx(doc_idx_ptr) + # Checks. + assert indexed_dataset.doc_idx[0] == 0 + assert indexed_dataset.doc_idx.shape[0] == \ + (total_num_of_documents + 1) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, skip_warmup, + binary_head, + dataset_type='standard_bert'): + + if len(data_prefix) == 1: + return _build_train_valid_test_datasets(data_prefix[0], + data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, + skip_warmup, + binary_head, + dataset_type=dataset_type) + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + prefixes[i], data_impl, splits_string, + datasets_train_valid_test_num_samples[i], + max_seq_length, masked_lm_prob, short_seq_prob, + seed, skip_warmup, binary_head, dataset_type=dataset_type) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights) + + return (blending_train_dataset, blending_valid_dataset, + blending_test_dataset) diff --git a/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..cf547ad9755815dfc0e9de449d147fe928a82948 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py @@ -0,0 +1,153 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Dataloaders.""" + +import torch +import random +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode + + +def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0): + """Build dataloader given an input dataset.""" + + if dataset is None: + return None + + # Megatron sampler + if dataloader_type == 'single': + batch_sampler = MegatronPretrainingSampler(total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), + data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) + elif dataloader_type == 'cyclic': + batch_sampler = MegatronPretrainingRandomSampler(total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), + data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) + else: + raise Exception('{} dataloader type is not supported.'.format(dataloader_type)) + + # Torch dataloader. + return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True) + + +class MegatronPretrainingSampler: + + def __init__(self, + total_samples, + consumed_samples, + micro_batch_size, + data_parallel_rank, + data_parallel_size, + drop_last=True): + # Keep a copy of input params for later use. + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.micro_batch_times_data_parallel_size = \ + self.micro_batch_size * data_parallel_size + self.drop_last = drop_last + + # Sanity checks. + assert self.total_samples > 0, \ + 'no sample to consume: {}'.format(self.total_samples) + assert self.consumed_samples < self.total_samples, \ + 'no samples left to consume: {}, {}'.format(self.consumed_samples, + self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, \ + 'data_parallel_rank should be smaller than data size: {}, ' \ + '{}'.format(self.data_parallel_rank, data_parallel_size) + + def __len__(self): + return self.total_samples + + def get_start_end_idx(self): + start_idx = self.data_parallel_rank * self.micro_batch_size + end_idx = start_idx + self.micro_batch_size + return start_idx, end_idx + + def __iter__(self): + batch = [] + # Last batch will be dropped if drop_last is not set False + for idx in range(self.consumed_samples, self.total_samples): + batch.append(idx) + if len(batch) == self.micro_batch_times_data_parallel_size: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + batch = [] + + # Check the last partial batch and see drop_last is set + if len(batch) > 0 and not self.drop_last: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + + +class MegatronPretrainingRandomSampler: + + def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size): + # Keep a copy of input params for later use. + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.data_parallel_size = data_parallel_size + self.micro_batch_times_data_parallel_size = \ + self.micro_batch_size * data_parallel_size + self.last_batch_size = \ + self.total_samples % self.micro_batch_times_data_parallel_size + + # Sanity checks. + assert self.total_samples > 0, \ + 'no sample to consume: {}'.format(self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, \ + 'data_parallel_rank should be smaller than data size: {}, ' \ + '{}'.format(self.data_parallel_rank, data_parallel_size) + + def __len__(self): + return self.total_samples + + def __iter__(self): + active_total_samples = self.total_samples - self.last_batch_size + self.epoch = self.consumed_samples // active_total_samples + current_epoch_samples = self.consumed_samples % active_total_samples + assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 + + # data sharding and random sampling + bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ + * self.micro_batch_size + bucket_offset = current_epoch_samples // self.data_parallel_size + start_idx = self.data_parallel_rank * bucket_size + + g = torch.Generator() + g.manual_seed(self.epoch) + random_idx = torch.randperm(bucket_size, generator=g).tolist() + idx_range = [start_idx + x for x in random_idx[bucket_offset:]] + + batch = [] + # Last batch if not complete will be dropped. + for idx in idx_range: + batch.append(idx) + if len(batch) == self.micro_batch_size: + self.consumed_samples += self.micro_batch_times_data_parallel_size + yield batch + batch = [] diff --git a/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py b/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf4e4763fc107ca9da3f063037bf08f4efc67cdd --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py @@ -0,0 +1,592 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors, and NVIDIA. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Most of the code here has been copied from: +# https://github.com/google-research/albert/blob/master/create_pretraining_data.py +# with some modifications. + +import math +import time +import collections +from colossalai.logging import get_dist_logger +import numpy as np +from .blendable_dataset import BlendableDataset +from .indexed_dataset import make_dataset as make_indexed_dataset + +DSET_TYPE_STD = 'standard_bert' +DSET_TYPE_ICT = 'ict' + +DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD] + + +def get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples): + + # The data prefix should be in the format of: + # weight-1, data-prefix-1, weight-2, data-prefix-2, .. + assert len(data_prefix) % 2 == 0 + num_datasets = len(data_prefix) // 2 + weights = [0]*num_datasets + prefixes = [0]*num_datasets + for i in range(num_datasets): + weights[i] = float(data_prefix[2*i]) + prefixes[i] = (data_prefix[2*i+1]).strip() + # Normalize weights + weight_sum = 0.0 + for weight in weights: + weight_sum += weight + assert weight_sum > 0.0 + weights = [weight / weight_sum for weight in weights] + + # Add 0.5% (the 1.005 factor) so in case the bleding dataset does + # not uniformly distribute the number of samples, we still have + # samples left to feed to the network. + datasets_train_valid_test_num_samples = [] + for weight in weights: + datasets_train_valid_test_num_samples.append( + [int(math.ceil(val * weight * 1.005)) + for val in train_valid_test_num_samples]) + + return prefixes, weights, datasets_train_valid_test_num_samples + + +def compile_helper(): + """Compile helper function ar runtime. Make sure this + is invoked on a single process.""" + import os + import subprocess + path = os.path.abspath(os.path.dirname(__file__)) + ret = subprocess.run(['make', '-C', path]) + if ret.returncode != 0: + print("Making C++ dataset helpers module failed, exiting.") + import sys + sys.exit(1) + + +def get_a_and_b_segments(sample, np_rng): + """Divide sample into a and b segments.""" + + # Number of sentences in the sample. + n_sentences = len(sample) + # Make sure we always have two sentences. + assert n_sentences > 1, 'make sure each sample has at least two sentences.' + + # First part: + # `a_end` is how many sentences go into the `A`. + a_end = 1 + if n_sentences >= 3: + # Note that randin in numpy is exclusive. + a_end = np_rng.randint(1, n_sentences) + tokens_a = [] + for j in range(a_end): + tokens_a.extend(sample[j]) + + # Second part: + tokens_b = [] + for j in range(a_end, n_sentences): + tokens_b.extend(sample[j]) + + # Random next: + is_next_random = False + if np_rng.random() < 0.5: + is_next_random = True + tokens_a, tokens_b = tokens_b, tokens_a + + return tokens_a, tokens_b, is_next_random + + +def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): + """Truncates a pair of sequences to a maximum sequence length.""" + #print(len_a, len_b, max_num_tokens) + assert len_a > 0 + if len_a + len_b <= max_num_tokens: + return False + while len_a + len_b > max_num_tokens: + if len_a > len_b: + len_a -= 1 + tokens = tokens_a + else: + len_b -= 1 + tokens = tokens_b + if np_rng.random() < 0.5: + del tokens[0] + else: + tokens.pop() + return True + + +def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): + """Merge segments A and B, add [CLS] and [SEP] and build tokentypes.""" + + tokens = [] + tokentypes = [] + # [CLS]. + tokens.append(cls_id) + tokentypes.append(0) + # Segment A. + for token in tokens_a: + tokens.append(token) + tokentypes.append(0) + # [SEP]. + tokens.append(sep_id) + tokentypes.append(0) + # Segment B. + for token in tokens_b: + tokens.append(token) + tokentypes.append(1) + if tokens_b: + # [SEP]. + tokens.append(sep_id) + tokentypes.append(1) + + return tokens, tokentypes + + +MaskedLmInstance = collections.namedtuple("MaskedLmInstance", + ["index", "label"]) + + +def is_start_piece(piece): + """Check if the current word piece is the starting piece (BERT).""" + # When a word has been split into + # WordPieces, the first token does not have any marker and any subsequence + # tokens are prefixed with ##. So whenever we see the ## token, we + # append it to the previous set of word indexes. + return not piece.startswith("##") + + +def create_masked_lm_predictions(tokens, + vocab_id_list, vocab_id_to_token_dict, + masked_lm_prob, + cls_id, sep_id, mask_id, + max_predictions_per_seq, + np_rng, + max_ngrams=3, + do_whole_word_mask=True, + favor_longer_ngram=False, + do_permutation=False): + """Creates the predictions for the masked LM objective. + Note: Tokens here are vocab ids and not text tokens.""" + + cand_indexes = [] + # Note(mingdachen): We create a list for recording if the piece is + # the starting piece of current token, where 1 means true, so that + # on-the-fly whole word masking is possible. + token_boundary = [0] * len(tokens) + + for (i, token) in enumerate(tokens): + if token == cls_id or token == sep_id: + token_boundary[i] = 1 + continue + # Whole Word Masking means that if we mask all of the wordpieces + # corresponding to an original word. + # + # Note that Whole Word Masking does *not* change the training code + # at all -- we still predict each WordPiece independently, softmaxed + # over the entire vocabulary. + if (do_whole_word_mask and len(cand_indexes) >= 1 and + not is_start_piece(vocab_id_to_token_dict[token])): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + if is_start_piece(vocab_id_to_token_dict[token]): + token_boundary[i] = 1 + + output_tokens = list(tokens) + + masked_lm_positions = [] + masked_lm_labels = [] + + if masked_lm_prob == 0: + return (output_tokens, masked_lm_positions, + masked_lm_labels, token_boundary) + + num_to_predict = min(max_predictions_per_seq, + max(1, int(round(len(tokens) * masked_lm_prob)))) + + # Note(mingdachen): + # By default, we set the probabilities to favor shorter ngram sequences. + ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64) + pvals = 1. / np.arange(1, max_ngrams + 1) + pvals /= pvals.sum(keepdims=True) + + if favor_longer_ngram: + pvals = pvals[::-1] + + ngram_indexes = [] + for idx in range(len(cand_indexes)): + ngram_index = [] + for n in ngrams: + ngram_index.append(cand_indexes[idx:idx + n]) + ngram_indexes.append(ngram_index) + + np_rng.shuffle(ngram_indexes) + + masked_lms = [] + covered_indexes = set() + for cand_index_set in ngram_indexes: + if len(masked_lms) >= num_to_predict: + break + if not cand_index_set: + continue + # Note(mingdachen): + # Skip current piece if they are covered in lm masking or previous ngrams. + for index_set in cand_index_set[0]: + for index in index_set: + if index in covered_indexes: + continue + + n = np_rng.choice(ngrams[:len(cand_index_set)], + p=pvals[:len(cand_index_set)] / + pvals[:len(cand_index_set)].sum(keepdims=True)) + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # Note(mingdachen): + # Repeatedly looking for a candidate that does not exceed the + # maximum number of predictions by trying shorter ngrams. + while len(masked_lms) + len(index_set) > num_to_predict: + if n == 0: + break + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(masked_lms) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + + masked_token = None + # 80% of the time, replace with [MASK] + if np_rng.random() < 0.8: + masked_token = mask_id + else: + # 10% of the time, keep original + if np_rng.random() < 0.5: + masked_token = tokens[index] + # 10% of the time, replace with random word + else: + masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))] + + output_tokens[index] = masked_token + + masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) + assert len(masked_lms) <= num_to_predict + + np_rng.shuffle(ngram_indexes) + + select_indexes = set() + if do_permutation: + for cand_index_set in ngram_indexes: + if len(select_indexes) >= num_to_predict: + break + if not cand_index_set: + continue + # Note(mingdachen): + # Skip current piece if they are covered in lm masking or previous ngrams. + for index_set in cand_index_set[0]: + for index in index_set: + if index in covered_indexes or index in select_indexes: + continue + + n = np.random.choice(ngrams[:len(cand_index_set)], + p=pvals[:len(cand_index_set)] / + pvals[:len(cand_index_set)].sum(keepdims=True)) + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + + while len(select_indexes) + len(index_set) > num_to_predict: + if n == 0: + break + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(select_indexes) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes or index in select_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + select_indexes.add(index) + assert len(select_indexes) <= num_to_predict + + select_indexes = sorted(select_indexes) + permute_indexes = list(select_indexes) + np_rng.shuffle(permute_indexes) + orig_token = list(output_tokens) + + for src_i, tgt_i in zip(select_indexes, permute_indexes): + output_tokens[src_i] = orig_token[tgt_i] + masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i])) + + masked_lms = sorted(masked_lms, key=lambda x: x.index) + + for p in masked_lms: + masked_lm_positions.append(p.index) + masked_lm_labels.append(p.label) + + return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) + + +def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, + masked_labels, pad_id, max_seq_length): + """Pad sequences and convert them to numpy.""" + + # Some checks. + num_tokens = len(tokens) + padding_length = max_seq_length - num_tokens + assert padding_length >= 0 + assert len(tokentypes) == num_tokens + assert len(masked_positions) == len(masked_labels) + + # Tokens and token types. + filler = [pad_id] * padding_length + tokens_np = np.array(tokens + filler, dtype=np.int64) + tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) + + # Padding mask. + padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, + dtype=np.int64) + + # Lables and loss mask. + labels = [-1] * max_seq_length + loss_mask = [0] * max_seq_length + for i in range(len(masked_positions)): + assert masked_positions[i] < num_tokens + labels[masked_positions[i]] = masked_labels[i] + loss_mask[masked_positions[i]] = 1 + labels_np = np.array(labels, dtype=np.int64) + loss_mask_np = np.array(loss_mask, dtype=np.int64) + + return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np + + +def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, skip_warmup, + binary_head, + dataset_type='standard_bert'): + + if len(data_prefix) == 1: + return _build_train_valid_test_datasets(data_prefix[0], + data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, + skip_warmup, + binary_head, + dataset_type=dataset_type) + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + prefixes[i], data_impl, splits_string, + datasets_train_valid_test_num_samples[i], + max_seq_length, masked_lm_prob, short_seq_prob, + seed, skip_warmup, binary_head, dataset_type=dataset_type) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights) + + return (blending_train_dataset, blending_valid_dataset, + blending_test_dataset) + + +def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, skip_warmup, + binary_head, + dataset_type='standard_bert'): + logger = get_dist_logger() + + if dataset_type not in DSET_TYPES: + raise ValueError("Invalid dataset_type: ", dataset_type) + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + skip_warmup) + + if dataset_type == DSET_TYPE_ICT: + args = get_args() + title_dataset = get_indexed_dataset_(args.titles_data_path, + data_impl, + skip_warmup) + + # Get start and end indices of train/valid/train into doc-idx + # Note that doc-idx is designed to be num-docs + 1 so we can + # easily iterate over it. + total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1 + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + # Print stats about the splits. + logger.info('\n > dataset split:') + + def print_split_stats(name, index): + start_index = indexed_dataset.doc_idx[splits[index]] + end_index = indexed_dataset.doc_idx[splits[index + 1]] + logger.info('\n {}:'.format(name) + + '\n document indices in [{}, {}) total of {} documents'.format( + splits[index], + splits[index + 1], + splits[index + 1] - splits[index]) + + '\n sentence indices in [{}, {}) total of {} sentences'.format( + start_index, + end_index, + end_index - start_index), + ranks=[0]) + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + from .bert_dataset import BertDataset + dataset = None + if splits[index + 1] > splits[index]: + # Get the pointer to the original doc-idx so we can set it later. + doc_idx_ptr = indexed_dataset.get_doc_idx() + # Slice the doc-idx + start_index = splits[index] + # Add +1 so we can index into the dataset to get the upper bound. + end_index = splits[index + 1] + 1 + # New doc_idx view. + indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) + # Build the dataset accordingly. + kwargs = dict( + name=name, + data_prefix=data_prefix, + num_epochs=None, + max_num_samples=train_valid_test_num_samples[index], + max_seq_length=max_seq_length, + seed=seed, + binary_head=binary_head + ) + + if dataset_type == DSET_TYPE_ICT: + args = get_args() + dataset = ICTDataset( + block_dataset=indexed_dataset, + title_dataset=title_dataset, + query_in_block_prob=args.query_in_block_prob, + use_one_sent_docs=args.use_one_sent_docs, + **kwargs + ) + else: + dataset = BertDataset( + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + short_seq_prob=short_seq_prob, + **kwargs + ) + + # Set the original pointer so dataset remains the main dataset. + indexed_dataset.set_doc_idx(doc_idx_ptr) + # Checks. + assert indexed_dataset.doc_idx[0] == 0 + assert indexed_dataset.doc_idx.shape[0] == \ + (total_num_of_documents + 1) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): + logger = get_dist_logger() + start_time = time.time() + indexed_dataset = make_indexed_dataset(data_prefix, + data_impl, + skip_warmup) + assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1] + logger.info('\n > building dataset index ...', ranks=[0]) + logger.info('\n > finished creating indexed dataset in {:4f} ' + 'seconds'.format(time.time() - start_time), ranks=[0]) + logger.info('\n > indexed dataset stats:' + + '\n number of documents: {}'.format( + indexed_dataset.doc_idx.shape[0] - 1) + + '\n number of sentences: {}'.format( + indexed_dataset.sizes.shape[0]), + ranks=[0] + ) + + return indexed_dataset + + +def get_train_valid_test_split_(splits_string, size): + """ Get dataset splits from comma or '/' separated string list.""" + + splits = [] + if splits_string.find(',') != -1: + splits = [float(s) for s in splits_string.split(',')] + elif splits_string.find('/') != -1: + splits = [float(s) for s in splits_string.split('/')] + else: + splits = [float(splits_string)] + while len(splits) < 3: + splits.append(0.) + splits = splits[:3] + splits_sum = sum(splits) + assert splits_sum > 0.0 + splits = [split / splits_sum for split in splits] + splits_index = [0] + for index, split in enumerate(splits): + splits_index.append(splits_index[index] + + int(round(split * float(size)))) + diff = splits_index[-1] - size + for index in range(1, len(splits_index)): + splits_index[index] -= diff + assert len(splits_index) == 4 + assert splits_index[-1] == size + return splits_index diff --git a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e45926a976961eb5094658ba478cb697a88c8000 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp @@ -0,0 +1,717 @@ +/* + coding=utf-8 + Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + + +/* Helper methods for fast index mapping builds */ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; +using namespace std; + +const int32_t LONG_SENTENCE_LEN = 512; + + +void build_blending_indices(py::array_t& dataset_index, + py::array_t& dataset_sample_index, + const py::array_t& weights, + const int32_t num_datasets, + const int64_t size, const bool verbose) { + /* Given multiple datasets and a weighting array, build samples + such that it follows those wieghts.*/ + + if (verbose) { + std::cout << "> building indices for blendable datasets ..." << std::endl; + } + + // Get the pointer access without the checks. + auto dataset_index_ptr = dataset_index.mutable_unchecked<1>(); + auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>(); + auto weights_ptr = weights.unchecked<1>(); + + // Initialize buffer for number of samples used for each dataset. + int64_t current_samples[num_datasets]; + for(int64_t i = 0; i < num_datasets; ++i) { + current_samples[i] = 0; + } + + // For each sample: + for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { + + // Determine where the max error in sampling is happening. + auto sample_idx_double = std::max(static_cast(sample_idx), 1.0); + int64_t max_error_index = 0; + double max_error = weights_ptr[0] * sample_idx_double - + static_cast(current_samples[0]); + for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { + double error = weights_ptr[dataset_idx] * sample_idx_double - + static_cast(current_samples[dataset_idx]); + if (error > max_error) { + max_error = error; + max_error_index = dataset_idx; + } + } + + // Populate the indices. + dataset_index_ptr[sample_idx] = static_cast(max_error_index); + dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index]; + + // Update the total samples. + current_samples[max_error_index] += 1; + + } + + // print info + if (verbose) { + std::cout << " > sample ratios:" << std::endl; + for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { + auto ratio = static_cast(current_samples[dataset_idx]) / + static_cast(size); + std::cout << " dataset " << dataset_idx << ", input: " << + weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; + } + } + +} + + +py::array build_sample_idx(const py::array_t& sizes_, + const py::array_t& doc_idx_, + const int32_t seq_length, + const int32_t num_epochs, + const int64_t tokens_per_epoch) { + /* Sample index (sample_idx) is used for gpt2 like dataset for which + the documents are flattened and the samples are built based on this + 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] + where [..., 0] contains the index into `doc_idx` and [..., 1] is the + starting offset in that document.*/ + + // Consistency checks. + assert(seq_length > 1); + assert(num_epochs > 0); + assert(tokens_per_epoch > 1); + + // Remove bound checks. + auto sizes = sizes_.unchecked<1>(); + auto doc_idx = doc_idx_.unchecked<1>(); + + // Mapping and it's length (1D). + int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; + int32_t* sample_idx = new int32_t[2*(num_samples+1)]; + + cout << " using:" << endl << std::flush; + cout << " number of documents: " << + doc_idx_.shape(0) / num_epochs << endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " sequence length: " << seq_length << + endl << std::flush; + cout << " total number of samples: " << num_samples << + endl << std::flush; + + // Index into sample_idx. + int64_t sample_index = 0; + // Index into doc_idx. + int64_t doc_idx_index = 0; + // Begining offset for each document. + int32_t doc_offset = 0; + // Start with first document and no offset. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + + while (sample_index <= num_samples) { + // Start with a fresh sequence. + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length != 0) { + // Get the document length. + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. + remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. + if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + remaining_seq_length = 0; + } else { + // Otherwise, start from the begining of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + } + + // Method to deallocate memory. + py::capsule free_when_done(sample_idx, [](void *mem_) { + int32_t *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(int32_t); + return py::array(std::vector{num_samples+1, 2}, // shape + {2*byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done); // numpy array references + +} + + +inline int32_t get_target_sample_len(const int32_t short_seq_ratio, + const int32_t max_length, + std::mt19937& rand32_gen) { + /* Training sample length. */ + if (short_seq_ratio == 0) { + return max_length; + } + const auto random_number = rand32_gen(); + if ((random_number % short_seq_ratio) == 0) { + return 2 + random_number % (max_length - 1); + } + return max_length; +} + + +template +py::array build_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const double short_seq_prob, + const int32_t seed, + const bool verbose, + const int32_t min_num_sent) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(short_seq_prob >= 0.0); + assert(short_seq_prob <= 1.0); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + + // For efficiency, convert probability to ratio. Note: rand() generates int. + int32_t short_seq_ratio = 0; + if (short_seq_prob > 0) { + short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); + } + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " short sequence probability: " << short_seq_prob << + endl << std::flush; + cout << " short sequence ration (1/prob): " << short_seq_ratio << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and it's length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the seed so both iterations produce the same results. + std::mt19937 rand32_gen(seed); + + // Set the flag on second iteration. + second = (iteration == 1); + + // Counters: + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + + // Current map index. + uint64_t map_index = 0; + + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent > 1) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + + // If we have more than two sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + auto target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and if not only one sentence is left in the document. + // and if we have at least two sentneces. + // and if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent > 1) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Check for overflow. + if ((3 * map_index + 2) > + std::numeric_limits::max()) { + cout << "number of samples exceeded maximum " + << "allowed by type int64: " + << std::numeric_limits::max() + << endl; + throw std::overflow_error("Number of samples"); + } + + // Populate the map. + if (second) { + const auto map_index_0 = 3 * map_index; + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(target_seq_len); + } + + // Update indices / counters. + ++map_index; + prev_start_index = sent_index + 1; + target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + seq_len = 0; + num_sent = 0; + } + + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[3*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 3 * i; + const auto j0 = 3 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 3}, // shape + {3*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + + +py::array build_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const double short_seq_prob, + const int seed, + const bool verbose, + const int32_t min_num_sent) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } +} + +template +py::array build_blocks_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const int32_t seed, + const bool verbose, + const bool use_one_sent_blocks) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + auto titles_sizes = titles_sizes_.unchecked<1>(); + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and its length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Acceptable number of sentences per block. + int min_num_sent = 2; + if (use_one_sent_blocks) { + min_num_sent = 1; + } + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the flag on second iteration. + second = (iteration == 1); + + // Current map index. + uint64_t map_index = 0; + + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + const auto target_seq_len = max_seq_length - titles_sizes[doc]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent >= min_num_sent) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + // If we have enough sentences and no long sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and there are an acceptable number of sentences left + // and if we have at least the minimum number of sentences. + // or if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent >= min_num_sent) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Populate the map. + if (second) { + const auto map_index_0 = 4 * map_index; + // Each sample has 4 items: the starting sentence index, ending sentence index, + // the index of the document from which the block comes (used for fetching titles) + // and the unique id of the block (used for creating block indexes) + + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(doc); + maps[map_index_0 + 3] = static_cast(block_id); + } + + // Update indices / counters. + ++map_index; + ++block_id; + prev_start_index = sent_index + 1; + seq_len = 0; + num_sent = 0; + } + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[4*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 4 * i; + const auto j0 = 4 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + swap(maps[i0 + 3], maps[j0 + 3]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 4}, // shape + {4*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + +py::array build_blocks_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const int seed, + const bool verbose, + const bool use_one_sent_blocks) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } +} + +PYBIND11_MODULE(helpers, m) { + m.def("build_mapping", &build_mapping); + m.def("build_blocks_mapping", &build_blocks_mapping); + m.def("build_sample_idx", &build_sample_idx); + m.def("build_blending_indices", &build_blending_indices); +} diff --git a/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6dac35ff9d413898146df5c1cc8553719e142105 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py @@ -0,0 +1,156 @@ +import itertools +import random + +import numpy as np +from torch.utils.data import Dataset + +from megatron import get_tokenizer +from megatron import get_args +from megatron.data.dataset_utils import get_indexed_dataset_ +from megatron.data.realm_dataset_utils import get_block_samples_mapping + +def make_attention_mask(source_block, target_block): + """ + Returns a 2-dimensional (2-D) attention mask + :param source_block: 1-D array + :param target_block: 1-D array + """ + mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1) + mask = mask.astype(np.int64) + # (source_length, target_length) + return mask + +def get_ict_dataset(use_titles=True, query_in_block_prob=1): + """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block()) + rather than for training, since it is only built with a single epoch sample mapping. + """ + args = get_args() + block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True) + titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True) + + kwargs = dict( + name='full', + block_dataset=block_dataset, + title_dataset=titles_dataset, + data_prefix=args.data_path, + num_epochs=1, + max_num_samples=None, + max_seq_length=args.seq_length, + seed=1, + query_in_block_prob=query_in_block_prob, + use_titles=use_titles, + use_one_sent_docs=args.use_one_sent_docs + ) + dataset = ICTDataset(**kwargs) + return dataset + + +class ICTDataset(Dataset): + """Dataset containing sentences and their blocks for an inverse cloze task.""" + def __init__(self, name, block_dataset, title_dataset, data_prefix, + num_epochs, max_num_samples, max_seq_length, query_in_block_prob, + seed, use_titles=True, use_one_sent_docs=False, binary_head=False): + self.name = name + self.seed = seed + self.max_seq_length = max_seq_length + self.query_in_block_prob = query_in_block_prob + self.block_dataset = block_dataset + self.title_dataset = title_dataset + self.rng = random.Random(self.seed) + self.use_titles = use_titles + self.use_one_sent_docs = use_one_sent_docs + + self.samples_mapping = get_block_samples_mapping( + block_dataset, title_dataset, data_prefix, num_epochs, + max_num_samples, max_seq_length, seed, name, use_one_sent_docs) + self.tokenizer = get_tokenizer() + self.vocab_id_list = list(self.tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_list = self.tokenizer.inv_vocab + self.cls_id = self.tokenizer.cls + self.sep_id = self.tokenizer.sep + self.mask_id = self.tokenizer.mask + self.pad_id = self.tokenizer.pad + + def __len__(self): + return len(self.samples_mapping) + + def __getitem__(self, idx): + """Get an ICT example of a pseudo-query and the block of text from which it was extracted""" + sample_data = self.samples_mapping[idx] + start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple() + + if self.use_titles: + title = self.title_dataset[int(doc_idx)] + title_pad_offset = 3 + len(title) + else: + title = None + title_pad_offset = 2 + block = [self.block_dataset[i] for i in range(start_idx, end_idx)] + assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1 + + # randint() is inclusive for Python rng + rand_sent_idx = self.rng.randint(0, len(block) - 1) + + # keep the query in the context query_in_block_prob fraction of the time. + if self.rng.random() < self.query_in_block_prob: + query = block[rand_sent_idx].copy() + else: + query = block.pop(rand_sent_idx) + + # still need to truncate because blocks are concluded when + # the sentence lengths have exceeded max_seq_length. + query = query[:self.max_seq_length - 2] + block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset] + + query_tokens, query_pad_mask = self.concat_and_pad_tokens(query) + context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title) + + query_mask = make_attention_mask(query_tokens, query_tokens) + context_mask = make_attention_mask(context_tokens, context_tokens) + + block_data = sample_data.as_array() + + sample = { + 'query_tokens': query_tokens, + 'query_mask': query_mask, + 'query_pad_mask': query_pad_mask, + 'context_tokens': context_tokens, + 'context_mask': context_mask, + 'context_pad_mask': context_pad_mask, + 'block_data': block_data, + } + + return sample + + def get_block(self, start_idx, end_idx, doc_idx): + """Get the IDs for an evidence block plus the title of the corresponding document""" + block = [self.block_dataset[i] for i in range(start_idx, end_idx)] + title = self.title_dataset[int(doc_idx)] + + block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))] + block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) + + return block_tokens, block_pad_mask + + def get_null_block(self): + """Get empty block and title - used in REALM pretraining""" + block, title = [], [] + block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) + + return block_tokens, block_pad_mask + + def concat_and_pad_tokens(self, tokens, title=None): + """Concat with special tokens and pad sequence to self.max_seq_length""" + tokens = list(tokens) + if title is None: + tokens = [self.cls_id] + tokens + [self.sep_id] + else: + title = list(title) + tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id] + assert len(tokens) <= self.max_seq_length + + num_pad = self.max_seq_length - len(tokens) + pad_mask = [1] * len(tokens) + [0] * num_pad + tokens += [self.pad_id] * num_pad + + return np.array(tokens), np.array(pad_mask) diff --git a/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b4febcd822e1d8640a69e2f31287caad0bba39f7 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py @@ -0,0 +1,569 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# copied from fairseq/fairseq/data/indexed_dataset.py +# Removed IndexedRawTextDataset since it relied on Fairseq dictionary +# other slight modifications to remove fairseq dependencies +# Added document index to index file and made it accessible. +# An empty sentence no longer separates documents. + +from functools import lru_cache +import os +import shutil +import struct +from itertools import accumulate + +import numpy as np +import torch + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def get_available_dataset_impl(): + return ['lazy', 'cached', 'mmap'] + + +def infer_dataset_impl(path): + if IndexedDataset.exists(path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return 'cached' + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return 'mmap' + else: + return None + else: + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + + +def make_builder(out_file, impl, vocab_size=None): + if impl == 'mmap': + return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) + else: + return IndexedDatasetBuilder(out_file) + + +def make_dataset(path, impl, skip_warmup=False): + if not IndexedDataset.exists(path): + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + if impl == 'infer': + impl = infer_dataset_impl(path) + if impl == 'lazy' and IndexedDataset.exists(path): + return IndexedDataset(path) + elif impl == 'cached' and IndexedDataset.exists(path): + return IndexedCachedDataset(path) + elif impl == 'mmap' and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path, skip_warmup) + print(f"Unknown dataset implementation: {impl}") + return None + + +def dataset_exists(path, impl): + if impl == 'mmap': + return MMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float, + 7: np.double, + 8: np.uint16 +} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + '.idx' + + +def data_file_path(prefix_path): + return prefix_path + '.bin' + + +def create_doc_idx(sizes): + doc_idx = [0] + for i, s in enumerate(sizes): + if s == 0: + doc_idx.append(i + 1) + return doc_idx + + +class IndexedDataset(torch.utils.data.Dataset): + """Loader for IndexedDataset""" + _HDR_MAGIC = b'TNTIDX\x00\x00' + + def __init__(self, path): + super().__init__() + self.path = path + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + 'Index file doesn\'t match expected format. ' + 'Make sure that --dataset-impl is configured properly.' + ) + version = f.read(8) + assert struct.unpack('= self._len: + raise IndexError('index out of range') + + def __del__(self): + if self.data_file: + self.data_file.close() + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if not self.data_file: + self.read_data(self.path) + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + return a + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]] + size = sum(sizes) + a = np.empty(size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[start] * self.element_size) + self.data_file.readinto(a) + offsets = list(accumulate(sizes)) + sents = np.split(a, offsets[:-1]) + return sents + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return ( + os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + ) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + + def __init__(self, path): + super().__init__(path) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx: ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx: ptx + a.size]) + return a + elif isinstance(idx, slice): + # Hack just to make this work, can optimizer later if necessary + sents = [] + for i in range(*idx.indices(len(self))): + sents.append(self[i]) + return sents + + +class IndexedDatasetBuilder(object): + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.int32: 4, + np.int64: 8, + np.float: 4, + np.double: 8 + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, 'wb') + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + self.doc_idx = [0] + + def add_item(self, tensor): + bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype)) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for s in tensor.size(): + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) + + def end_document(self): + self.doc_idx.append(len(self.sizes)) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + begin = self.data_offsets[-1] + for offset in index.data_offsets[1:]: + self.data_offsets.append(begin + offset) + self.sizes.extend(index.sizes) + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + with open(data_file_path(another_file), 'rb') as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, 'wb') + index.write(b'TNTIDX\x00\x00') + index.write(struct.pack(' len(ds.doc_idx) - 1: + args.count = len(ds.doc_idx) - 1 + + for i in range(args.count): + start = ds.doc_idx[i] + end = ds.doc_idx[i + 1] + ids = ds[start:end] + print(f"Document {i}:") + print("--------------") + for s in ids: + assert len(s) > 0 + l = s.data.tolist() + text = tokenizer.detokenize(l) + print(text) + print("---") + + +def test_indexed_dataset_get(args): + ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) + tokenizer = build_tokenizer(args) + size = ds.sizes[0] + print(f"size: {size}") + full = ds.get(0) + print(full) + # print(tokenizer.detokenize(full.data.tolist())) + print("---") + end = ds.get(0, offset=size - 10) + print(end) + # print(tokenizer.detokenize(end.data.tolist())) + + start = ds.get(0, length=10) + print(start) + # print(tokenizer.detokenize(start.data.tolist())) + + part = ds.get(0, offset=2, length=8) + print(part) + # print(tokenizer.detokenize(part.data.tolist())) + +# def test_albert_dataset(args): +# # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True) +# # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl) +# # ds = AlbertDataset(idataset, tokenizer) +# ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl, +# args.epochs, args.max_num_samples, +# args.masked_lm_prob, args.seq_length, +# args.short_seq_prob, args.seed) +# truncated = 0 +# total = 0 +# for i, s in enumerate(ds): +# ids = s['text'] +# tokens = ds.tokenizer.convert_ids_to_tokens(ids) +# print(tokens) +# if i >= args.count-1: +# exit() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--data', type=str, help='prefix to data files') + parser.add_argument('--dataset-impl', type=str, default='infer', + choices=['lazy', 'cached', 'mmap', 'infer']) + parser.add_argument('--count', type=int, default=10, + help='Number of samples/documents to print') + + group = parser.add_argument_group(title='tokenizer') + group.add_argument('--tokenizer-type', type=str, required=True, + choices=['BertWordPieceLowerCase', + 'GPT2BPETokenizer'], + help='What type of tokenizer to use.') + group.add_argument('--vocab-file', type=str, default=None, + help='Path to the vocab file') + group.add_argument('--merge-file', type=str, default=None, + help='Path to the BPE merge file (if necessary).') + + parser.add_argument('--epochs', type=int, default=5, + help='Number of epochs to plan for') + parser.add_argument('--max-num-samples', type=int, default=None, + help='Maximum number of samples to plan for') + parser.add_argument('--masked-lm-prob', type=float, default=0.15, + help='probability of masking tokens') + parser.add_argument('--seq-length', type=int, default=512, + help='maximum sequence length') + parser.add_argument('--short-seq-prob', type=float, default=0.1, + help='probability of creating a short sequence') + parser.add_argument('--seed', type=int, default=1234, + help='random seed') + args = parser.parse_args() + args.rank = 0 + args.make_vocab_size_divisible_by = 128 + args.tensor_model_parallel_size = 1 + + if args.dataset_impl == "infer": + args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data) + +# test_albert_dataset(args) + test_indexed_dataset_get(args) + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/sequence_parallel/data/datasets/test/test_preprocess_data.sh b/examples/tutorial/sequence_parallel/data/datasets/test/test_preprocess_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..d121c85958ff35e37431befdceabb831c8cd2705 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/test/test_preprocess_data.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +IMPL=cached +python ../preprocess_data.py \ + --input test_samples.json \ + --vocab vocab.txt \ + --dataset-impl ${IMPL} \ + --output-prefix test_samples_${IMPL} \ + --workers 1 \ + --log-interval 2 diff --git a/examples/tutorial/sequence_parallel/data/dummy_dataloader.py b/examples/tutorial/sequence_parallel/data/dummy_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..faa90175cc60fa06283d1ae87ae992569226005c --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/dummy_dataloader.py @@ -0,0 +1,39 @@ +import torch + + +class DummyDataloader(): + + def __init__(self, batch_size, vocab_size, seq_length): + self.batch_size = batch_size + self.vocab_size = vocab_size + self.seq_length = seq_length + self.step = 0 + + def generate(self): + tokens = torch.randint(low=0, high=self.vocab_size, size=( + self.batch_size, + self.seq_length, + )) + types = torch.randint(low=0, high=3, size=( + self.batch_size, + self.seq_length, + )) + sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,)) + loss_mask = torch.randint(low=0, high=2, size=( + self.batch_size, + self.seq_length, + )) + lm_labels = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length)) + padding_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length)) + return dict(text=tokens, + types=types, + is_random=sentence_order, + loss_mask=loss_mask, + labels=lm_labels, + padding_mask=padding_mask) + + def __iter__(self): + return self + + def __next__(self): + return self.generate() \ No newline at end of file diff --git a/examples/tutorial/sequence_parallel/data/tokenizer/__init__.py b/examples/tutorial/sequence_parallel/data/tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df27f14247eba7b887f73ca1022f08b1597026dd --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/tokenizer/__init__.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .tokenizer import build_tokenizer + + +_TOKENIZER = None +_PADDED_VOCAB_SIZE = -1 + + +def initialize_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): + tokenizer, padded_vocab_size = build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids) + global _TOKENIZER, _PADDED_VOCAB_SIZE + _TOKENIZER = tokenizer + _PADDED_VOCAB_SIZE = padded_vocab_size + + +def get_tokenizer(): + global _TOKENIZER + return _TOKENIZER + + +def get_padded_vocab_size(): + global _PADDED_VOCAB_SIZE + return _PADDED_VOCAB_SIZE diff --git a/examples/tutorial/sequence_parallel/data/tokenizer/bert_tokenization.py b/examples/tutorial/sequence_parallel/data/tokenizer/bert_tokenization.py new file mode 100644 index 0000000000000000000000000000000000000000..1be494793909e878132a659104a8ecc597f4552e --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/tokenizer/bert_tokenization.py @@ -0,0 +1,431 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re +import unicodedata +import six + + +def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): + """Checks whether the casing config is consistent with the checkpoint name.""" + + # The casing has to be passed in by the user and there is no explicit check + # as to whether it matches the checkpoint. The casing information probably + # should have been stored in the bert_config.json file, but it's not, so + # we have to heuristically detect it to validate. + + if not init_checkpoint: + return + + m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) + if m is None: + return + + model_name = m.group(1) + + lower_models = [ + "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", + "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" + ] + + cased_models = [ + "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", + "multi_cased_L-12_H-768_A-12" + ] + + is_bad_config = False + if model_name in lower_models and not do_lower_case: + is_bad_config = True + actual_flag = "False" + case_name = "lowercased" + opposite_flag = "True" + + if model_name in cased_models and do_lower_case: + is_bad_config = True + actual_flag = "True" + case_name = "cased" + opposite_flag = "False" + + if is_bad_config: + raise ValueError( + "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " + "However, `%s` seems to be a %s model, so you " + "should pass in `--do_lower_case=%s` so that the fine-tuning matches " + "how the model was pre-training. If this error is wrong, please " + "just comment out this check." % (actual_flag, init_checkpoint, + model_name, case_name, opposite_flag)) + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + elif isinstance(text, unicode): + return text + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def printable_text(text): + """Returns text encoded in a way suitable for print or `tf.logging`.""" + + # These functions want `str` for both Python2 and Python3, but in one case + # it's a Unicode string and in the other it's a byte string. + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text + elif isinstance(text, unicode): + return text.encode("utf-8") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r") as reader: + while True: + token = convert_to_unicode(reader.readline()) + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def convert_by_vocab(vocab, items): + """Converts a sequence of [tokens|ids] using the vocab.""" + output = [] + for item in items: + output.append(vocab[item]) + return output + + +def convert_tokens_to_ids(vocab, tokens): + return convert_by_vocab(vocab, tokens) + + +def convert_ids_to_tokens(inv_vocab, ids): + return convert_by_vocab(inv_vocab, ids) + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class FullTokenizer(object): + """Runs end-to-end tokenization.""" + + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + + return split_tokens + + def convert_tokens_to_ids(self, tokens): + return convert_by_vocab(self.vocab, tokens) + + def convert_ids_to_tokens(self, ids): + return convert_by_vocab(self.inv_vocab, ids) + + @staticmethod + def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): + """ Converts a sequence of tokens (string) in a single string. """ + + def clean_up_tokenization(out_string): + """ Clean up a list of simple English tokenization artifacts + like spaces before punctuations and abbreviated forms. + """ + out_string = ( + out_string.replace(" .", ".") + .replace(" ?", "?") + .replace(" !", "!") + .replace(" ,", ",") + .replace(" ' ", "'") + .replace(" n't", "n't") + .replace(" 'm", "'m") + .replace(" 's", "'s") + .replace(" 've", "'ve") + .replace(" 're", "'re") + ) + return out_string + + text = ' '.join(tokens).replace(' ##', '').strip() + if clean_up_tokenization_spaces: + clean_text = clean_up_tokenization(text) + return clean_text + else: + return text + + def vocab_size(self): + return len(self.vocab) + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, do_lower_case=True): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = convert_to_unicode(text) + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer. + + Returns: + A list of wordpiece tokens. + """ + + text = convert_to_unicode(text) + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically control characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat in ("Cc", "Cf"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False diff --git a/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3c923e8e76517b8a5244608d6c18d080082952 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py @@ -0,0 +1,256 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron tokenizers.""" + +from abc import ABC +from abc import abstractmethod +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode + +from .bert_tokenization import FullTokenizer as FullBertTokenizer + + +def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): + """Initialize tokenizer.""" + if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: + print('> building {} tokenizer ...'.format(tokenizer_type), + flush=True) + + # Select and instantiate the tokenizer. + if tokenizer_type == 'BertWordPieceLowerCase': + tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, + lower_case=True, + vocab_extra_ids=vocab_extra_ids) + elif tokenizer_type == 'BertWordPieceCase': + tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, + lower_case=False, + vocab_extra_ids=vocab_extra_ids) + else: + raise NotImplementedError('{} tokenizer is not ' + 'implemented.'.format(tokenizer_type)) + + # Add vocab size. + padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size) + + return tokenizer, padded_vocab_size + + +def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128): + """Pad vocab size so it is divisible by model parallel size and + still having GPU friendly size.""" + + after = orig_vocab_size + + if gpc.is_initialized(ParallelMode.TENSOR): + multiple = make_vocab_size_divisible_by * gpc.get_world_size(ParallelMode.TENSOR) + else: + multiple = make_vocab_size_divisible_by + while (after % multiple) != 0: + after += 1 + if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: + print(' > padded vocab (size: {}) with {} dummy tokens ' + '(new size: {})'.format( + orig_vocab_size, after - orig_vocab_size, after), flush=True) + return after + + +class AbstractTokenizer(ABC): + """Abstract class for tokenizer.""" + + def __init__(self, name): + self.name = name + super().__init__() + + @property + @abstractmethod + def vocab_size(self): + pass + + @property + @abstractmethod + def vocab(self): + """Dictionary from vocab text token to id token.""" + pass + + @property + @abstractmethod + def inv_vocab(self): + """Dictionary from vocab id token to text token.""" + pass + + @abstractmethod + def tokenize(self, text): + pass + + def detokenize(self, token_ids): + raise NotImplementedError('detokenizer is not implemented for {} ' + 'tokenizer'.format(self.name)) + + @property + def cls(self): + raise NotImplementedError('CLS is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def sep(self): + raise NotImplementedError('SEP is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def pad(self): + raise NotImplementedError('PAD is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def eod(self): + raise NotImplementedError('EOD is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def mask(self): + raise NotImplementedError('MASK is not provided for {} ' + 'tokenizer'.format(self.name)) + + +class _BertWordPieceTokenizer(AbstractTokenizer): + """Original BERT wordpiece tokenizer.""" + + def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): + if lower_case: + name = 'BERT Lower Case' + else: + name = 'BERT Upper Case' + super().__init__(name) + self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case) + self.cls_id = self.tokenizer.vocab['[CLS]'] + self.sep_id = self.tokenizer.vocab['[SEP]'] + self.pad_id = self.tokenizer.vocab['[PAD]'] + self.mask_id = self.tokenizer.vocab['[MASK]'] + self._additional_special_tokens = [] + + # (dsachan) Add BOS and EOS tokens + SPECIAL_TOKENS = {'eos_token': '[EOS]', + 'bos_token': '[BOS]'} + self._bos_token = '[BOS]' + self.add_token(self._bos_token) + self._bos_token_id = self.vocab.get(self._bos_token) + + self._eos_token = '[EOS]' + self.add_token(self._eos_token) + self._eos_token_id = self.vocab.get(self._eos_token) + + # (dsachan) Add additional special tokens + # These can be used as sentinel tokens in T5 model inputs + additional_special_tokens = [] + additional_special_tokens.extend( + ["".format(i) for i in range(vocab_extra_ids)]) + self.add_additional_special_tokens(additional_special_tokens) + + def add_token(self, token): + if token not in self.vocab: + self.inv_vocab[self.vocab_size] = token + # self.vocab_size comes from len(vocab) + # and it will increase as we add elements + self.vocab[token] = self.vocab_size + + def add_additional_special_tokens(self, tokens_list): + setattr(self, "additional_special_tokens", tokens_list) + for value in tokens_list: + self.add_token(value) + + @property + def vocab_size(self): + return self.tokenizer.vocab_size() + + @property + def vocab(self): + return self.tokenizer.vocab + + @property + def inv_vocab(self): + return self.tokenizer.inv_vocab + + def tokenize(self, text): + text_tokens = self.tokenizer.tokenize(text) + return self.tokenizer.convert_tokens_to_ids(text_tokens) + + def decode(self, ids): + tokens = self.tokenizer.convert_ids_to_tokens(ids) + return self.tokenizer.convert_tokens_to_string(tokens) + + def decode_token_ids(self, token_ids): + tokens = self.tokenizer.convert_ids_to_tokens(token_ids) + exclude_list = ['[PAD]', '[CLS]'] + non_pads = [t for t in tokens if t not in exclude_list] + + result = "" + for s in non_pads: + if s.startswith("##"): + result += s[2:] + else: + result += " " + s + + return result + + @property + def cls(self): + return self.cls_id + + @property + def sep(self): + return self.sep_id + + @property + def pad(self): + return self.pad_id + + @property + def mask(self): + return self.mask_id + + @property + def bos_token(self): + """ Beginning of sentence token id """ + return self._bos_token + + @property + def eos_token(self): + """ End of sentence token id """ + return self._eos_token + + @property + def additional_special_tokens(self): + """ All the additional special tokens you may want to use (list of strings).""" + return self._additional_special_tokens + + @property + def bos_token_id(self): + """ Id of the beginning of sentence token in the vocabulary.""" + return self._bos_token_id + + @property + def eos_token_id(self): + """ Id of the end of sentence token in the vocabulary.""" + return self._eos_token_id + + @property + def additional_special_tokens_ids(self): + """ Ids of all the additional special tokens in the vocabulary (list of integers).""" + return [self.vocab.get(token) for token in self._additional_special_tokens] + + @additional_special_tokens.setter + def additional_special_tokens(self, value): + self._additional_special_tokens = value diff --git a/examples/tutorial/sequence_parallel/loss_func/__init__.py b/examples/tutorial/sequence_parallel/loss_func/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/tutorial/sequence_parallel/loss_func/bert_loss.py b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e87a778cf5d5959f68c134957f101787ca87cffe --- /dev/null +++ b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.logging import get_dist_logger +import torch.nn.functional as F +import torch.distributed as dist +from .cross_entropy import vocab_cross_entropy + + +class BertLoss(nn.Module): + + def forward(self, + lm_loss, + sop_logits, + loss_mask, + sentence_order): + lm_loss_ = lm_loss.float() + loss_mask = loss_mask.float() + loss_mask_sum = loss_mask.sum() + lm_loss = torch.sum( + lm_loss_.view(-1) * loss_mask.reshape(-1)) + + lm_loss /= loss_mask_sum + + torch.distributed.all_reduce( + lm_loss, + group=gpc.get_group(ParallelMode.SEQUENCE) + ) + + if sop_logits is not None: + sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), + sentence_order.view(-1), + ignore_index=-1) + sop_loss = sop_loss.float() + loss = lm_loss + sop_loss * gpc.get_world_size(ParallelMode.SEQUENCE) + else: + sop_loss = None + loss = lm_loss + + return loss diff --git a/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..54553c29a61f91b34d9b192ab734ff201a92892e --- /dev/null +++ b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py @@ -0,0 +1,75 @@ +from colossalai.context.parallel_mode import ParallelMode +import torch +from torch.cuda.amp import custom_bwd, custom_fwd + + +class _VocabCrossEntropy(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, vocab_parallel_logits, target): + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] + + # Subtract the maximum value. + vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = target < 0 + masked_target = target.clone() + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. + logits_2d = vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], + device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = vocab_parallel_logits + torch.exp(vocab_parallel_logits, out=exp_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits) - predicted_logits + + # Store softmax, target-mask and masked-target for backward pass. + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + + return loss + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + + # All the inputs have softmax as their gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], + device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= ( + 1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None + + +def vocab_cross_entropy(vocab_logits, target): + """helper function for the cross entropy.""" + + return _VocabCrossEntropy.apply(vocab_logits, target) diff --git a/examples/tutorial/sequence_parallel/loss_func/utils.py b/examples/tutorial/sequence_parallel/loss_func/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a3d92f294326312a6f3cb9be04ed38124b104485 --- /dev/null +++ b/examples/tutorial/sequence_parallel/loss_func/utils.py @@ -0,0 +1,55 @@ + +import torch + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, '{} is not divisible by {}'.format( + numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim(tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class VocabUtility: + """Split the vocabulary into `world_size` chunks amd return the + first and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last)""" + + @staticmethod + def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, + rank, world_size): + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + @staticmethod + def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): + per_partition_vocab_size = divide(global_vocab_size, world_size) + return VocabUtility.vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size, rank, world_size) diff --git a/examples/tutorial/sequence_parallel/lr_scheduler/__init__.py b/examples/tutorial/sequence_parallel/lr_scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b8b615bce63cd1750db9cb063261d51a52e8b30 --- /dev/null +++ b/examples/tutorial/sequence_parallel/lr_scheduler/__init__.py @@ -0,0 +1 @@ +from .annealing_lr import AnnealingLR diff --git a/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py b/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..8d95679ff76dcdab3020c02b30e9e8616d086c73 --- /dev/null +++ b/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Learning rate decay functions.""" + +import math + + +class AnnealingLR(object): + """Anneals the learning rate.""" + + def __init__(self, + optimizer, + max_lr, + min_lr, + warmup_steps, + decay_steps, + decay_style, + use_checkpoint_lr_scheduler=True, + override_lr_scheduler=False): + + # Class values. + self.optimizer = optimizer + + self.max_lr = float(max_lr) + self.min_lr = min_lr + assert self.min_lr >= 0.0 + assert self.max_lr >= self.min_lr + + self.warmup_steps = warmup_steps + self.num_steps = 0 + self.decay_steps = decay_steps + assert self.decay_steps > 0 + assert self.warmup_steps < self.decay_steps + + self.decay_style = decay_style + + self.override_lr_scheduler = override_lr_scheduler + self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler + if self.override_lr_scheduler: + assert not self.use_checkpoint_lr_scheduler, 'both override and '\ + 'use-checkpoint are set.' + + # Set the learning rate + self.step(0) + + def get_lr(self): + """Learning rate decay functions from: + https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" + + # Use linear warmup for the initial part. + if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps: + return self.max_lr * float(self.num_steps) / \ + float(self.warmup_steps) + + # If the learning rate is constant, just return the initial value. + if self.decay_style == 'constant': + return self.max_lr + + # For any steps larger than `self.decay_steps`, use `self.min_lr`. + if self.num_steps > self.decay_steps: + return self.min_lr + + # If we are done with the warmup period, use the decay style. + num_steps_ = self.num_steps - self.warmup_steps + decay_steps_ = self.decay_steps - self.warmup_steps + decay_ratio = float(num_steps_) / float(decay_steps_) + assert decay_ratio >= 0.0 + assert decay_ratio <= 1.0 + delta_lr = self.max_lr - self.min_lr + + if self.decay_style == 'linear': + coeff = (1.0 - decay_ratio) + elif self.decay_style == 'cosine': + coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) + else: + raise Exception('{} decay style is not supported.'.format( + self.decay_style)) + + return self.min_lr + coeff * delta_lr + + def step(self, increment=1): + """Set lr for all parameters groups.""" + self.num_steps += increment + new_lr = self.get_lr() + for group in self.optimizer.param_groups: + group['lr'] = new_lr + + def state_dict(self): + state_dict = { + 'max_lr': self.max_lr, + 'warmup_steps': self.warmup_steps, + 'num_steps': self.num_steps, + 'decay_style': self.decay_style, + 'decay_steps': self.decay_steps, + 'min_lr': self.min_lr + } + return state_dict + + def _check_and_set(self, cls_value, sd_value, name): + """Auxiliary function for checking the values in the checkpoint and + setting them.""" + if self.override_lr_scheduler: + return cls_value + + if not self.use_checkpoint_lr_scheduler: + assert cls_value == sd_value, \ + f'AnnealingLR: class input value {cls_value} and checkpoint' \ + f'value {sd_value} for {name} do not match' + return sd_value + + def load_state_dict(self, sd): + + if 'start_lr' in sd: + max_lr_ = sd['start_lr'] + else: + max_lr_ = sd['max_lr'] + self.max_lr = self._check_and_set(self.max_lr, max_lr_, + 'learning rate') + + self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], + 'minimum learning rate') + + if 'warmup_iter' in sd: + warmup_steps_ = sd['warmup_iter'] + else: + warmup_steps_ = sd['warmup_steps'] + self.warmup_steps = self._check_and_set(self.warmup_steps, + warmup_steps_, + 'warmup iterations') + + if 'end_iter' in sd: + decay_steps_ = sd['end_iter'] + else: + decay_steps_ = sd['decay_steps'] + self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, + 'total number of iterations') + self.decay_style = self._check_and_set(self.decay_style, + sd['decay_style'], + 'decay style') + + if 'num_iters' in sd: + num_steps = sd['num_iters'] + else: + num_steps = sd['num_steps'] + self.step(increment=num_steps) diff --git a/examples/tutorial/sequence_parallel/model/__init__.py b/examples/tutorial/sequence_parallel/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/examples/tutorial/sequence_parallel/model/bert.py b/examples/tutorial/sequence_parallel/model/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..049579c5a639c2bab29348fe0290b75497dbdc31 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/bert.py @@ -0,0 +1,282 @@ +from colossalai.context.parallel_mode import ParallelMode +import torch +import torch.nn as nn +import inspect +from .layers import Embedding, BertLayer, BertDualHead, PreProcessor, VocabEmbedding +from .layers.init_method import init_normal, output_init_normal +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.kernel import LayerNorm +from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper +from colossalai.logging import get_dist_logger +from colossalai.pipeline.utils import partition_uniform + + +class BertForPretrain(nn.Module): + + def __init__(self, + vocab_size, + hidden_size, + max_sequence_length, + num_attention_heads, + num_layers, + add_binary_head, + is_naive_fp16, + num_tokentypes=2, + dropout_prob=0.1, + mlp_ratio=4, + init_std=0.02, + convert_fp16_to_fp32_in_softmax=False, + ): + super().__init__() + self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE) + assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size' + self.sub_seq_length = max_sequence_length // self.seq_parallel_size + self.init_std = init_std + self.num_layers = num_layers + + if not add_binary_head: + num_tokentypes = 0 + + self.preprocessor = PreProcessor(self.sub_seq_length) + self.embedding = Embedding(hidden_size=hidden_size, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + embedding_dropout_prob=dropout_prob, + num_tokentypes=num_tokentypes) + self.bert_layers = nn.ModuleList() + + for i in range(num_layers): + bert_layer = BertLayer(layer_number=i+1, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=dropout_prob, + mlp_ratio=mlp_ratio, + hidden_dropout=dropout_prob, + convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, + is_naive_fp16=is_naive_fp16 + ) + self.bert_layers.append(bert_layer) + + self.layer_norm = LayerNorm(hidden_size) + self.head = BertDualHead(hidden_size, self.embedding.word_embedding_weight.size(0), + add_binary_head=add_binary_head) + self.reset_parameters() + + def _init_normal(self, tensor): + init_normal(tensor, sigma=self.init_std) + + def _output_init_normal(self, tensor): + output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers) + + def reset_parameters(self): + # initialize embedding + self._init_normal(self.embedding.word_embedding_weight) + self._init_normal(self.embedding.position_embeddings.weight) + if self.embedding.tokentype_embeddings: + self._init_normal(self.embedding.tokentype_embeddings.weight) + + # initialize bert layer + for layer in self.bert_layers: + # initialize self attention + self._init_normal(layer.self_attention.query_key_value.weight) + self._output_init_normal(layer.self_attention.dense.weight) + self._init_normal(layer.mlp.dense_h_to_4h.weight) + self._output_init_normal(layer.mlp.dense_4h_to_h.weight) + + # initializer head + self._init_normal(self.head.lm_head.dense.weight) + if self.head.binary_head is not None: + self._init_normal(self.head.binary_head.pooler.dense.weight) + self._init_normal(self.head.binary_head.dense.weight) + + def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels): + # inputs of the forward function + # input_ids: [batch_size, sub_seq_len] + # attention_mask: [batch_size, seq_len] + # tokentype_ids: [batch_size, sub_seq_len] + # outputs of preprocessor + # pos_ids: [batch_size, sub_seq_len] + # attention_masks: [batch_size, 1, sub_seq_len, seq_len] + pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks) + + hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids) + + # hidden_states shape change: + # [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size] + hidden_states = hidden_states.transpose(0, 1).contiguous() + + for idx, layer in enumerate(self.bert_layers): + hidden_states = layer(hidden_states, attention_masks) + + hidden_states = hidden_states.transpose(0, 1).contiguous() + output = self.layer_norm(hidden_states) + + # hidden_states: [sub_seq_len, batch_size, hidden_size] + # word_embedding: [vocab_size, hidden_size] + return self.head(output, self.embedding.word_embedding_weight, lm_labels) + + +class PipelineBertForPretrain(nn.Module): + + def __init__(self, + vocab_size, + hidden_size, + max_sequence_length, + num_attention_heads, + num_layers, + add_binary_head, + is_naive_fp16, + num_tokentypes=2, + dropout_prob=0.1, + mlp_ratio=4, + init_std=0.02, + convert_fp16_to_fp32_in_softmax=False, + first_stage=True, + last_stage=True, + start_idx=None, + end_idx=None): + super().__init__() + self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE) + assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size' + self.sub_seq_length = max_sequence_length // self.seq_parallel_size + self.init_std = init_std + self.num_layers = num_layers + + if not add_binary_head: + num_tokentypes = 0 + + self.first_stage = first_stage + self.last_stage = last_stage + + self.preprocessor = PreProcessor(self.sub_seq_length) + + if self.first_stage: + self.embedding = Embedding(hidden_size=hidden_size, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + embedding_dropout_prob=dropout_prob, + num_tokentypes=num_tokentypes) + + # transformer layers + self.bert_layers = nn.ModuleList() + + if start_idx is None and end_idx is None: + start_idx = 0 + end_idx = num_layers + + for i in range(start_idx, end_idx): + bert_layer = BertLayer(layer_number=i+1, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=dropout_prob, + mlp_ratio=mlp_ratio, + hidden_dropout=dropout_prob, + convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, + is_naive_fp16=is_naive_fp16 + ) + self.bert_layers.append(bert_layer) + + if self.last_stage: + self.word_embeddings = VocabEmbedding(vocab_size, hidden_size) + self.layer_norm = LayerNorm(hidden_size) + self.head = BertDualHead(hidden_size, vocab_size, + add_binary_head=add_binary_head) + self.reset_parameters() + + def _init_normal(self, tensor): + init_normal(tensor, sigma=self.init_std) + + def _output_init_normal(self, tensor): + output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers) + + def reset_parameters(self): + # initialize embedding + if self.first_stage: + self._init_normal(self.embedding.word_embedding_weight) + self._init_normal(self.embedding.position_embeddings.weight) + if self.embedding.tokentype_embeddings: + self._init_normal(self.embedding.tokentype_embeddings.weight) + + # initialize bert layer + for layer in self.bert_layers: + # initialize self attention + self._init_normal(layer.self_attention.query_key_value.weight) + self._output_init_normal(layer.self_attention.dense.weight) + self._init_normal(layer.mlp.dense_h_to_4h.weight) + self._output_init_normal(layer.mlp.dense_4h_to_h.weight) + + # initializer head + if self.last_stage: + self._init_normal(self.head.lm_head.dense.weight) + if self.head.binary_head is not None: + self._init_normal(self.head.binary_head.pooler.dense.weight) + self._init_normal(self.head.binary_head.dense.weight) + + def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels): + # inputs of the forward function + # input_ids: [batch_size, sub_seq_len] + # attention_mask: [batch_size, seq_len] + # tokentype_ids: [batch_size, sub_seq_len] + # outputs of preprocessor + # pos_ids: [batch_size, sub_seq_len] + # attention_masks: [batch_size, 1, sub_seq_len, seq_len] + if self.first_stage: + pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks) + else: + _, attention_masks = self.preprocessor(None, attention_masks) + + if self.first_stage: + hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids) + hidden_states = hidden_states.transpose(0, 1).contiguous() + else: + hidden_states = input_ids + + # hidden_states shape change: + # [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size] + for idx, layer in enumerate(self.bert_layers): + hidden_states = layer(hidden_states, attention_masks) + + if self.last_stage: + hidden_states = hidden_states.transpose(0, 1).contiguous() + output = self.layer_norm(hidden_states) + output = self.head(output, self.word_embeddings.weight, lm_labels) + else: + output = hidden_states + + # hidden_states: [sub_seq_len, batch_size, hidden_size] + # word_embedding: [vocab_size, hidden_size] + return output + + +def _filter_kwargs(func, kwargs): + sig = inspect.signature(func) + return {k: v for k, v in kwargs.items() if k in sig.parameters} + + +def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): + logger = get_dist_logger() + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + rank = gpc.get_global_rank() + wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1]) + parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] + models = [] + for start, end in parts: + kwargs['num_layers'] = num_layers + kwargs['start_idx'] = start + kwargs['end_idx'] = end + kwargs['first_stage'] = start == 0 + kwargs['last_stage'] = end == num_layers + logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') + chunk = PipelineBertForPretrain(**_filter_kwargs(PipelineBertForPretrain.__init__, kwargs)).to(device) + if start == 0: + wrapper.register_module(chunk.embedding.word_embeddings) + elif end == num_layers: + wrapper.register_module(chunk.word_embeddings) + models.append(chunk) + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + return model diff --git a/examples/tutorial/sequence_parallel/model/layers/__init__.py b/examples/tutorial/sequence_parallel/model/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a8823caa81b4e1f3cdbf9ea4f79e012a1b99505 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/__init__.py @@ -0,0 +1,4 @@ +from .embedding import VocabEmbedding, Embedding +from .bert_layer import BertLayer +from .head import BertDualHead +from .preprocess import PreProcessor diff --git a/examples/tutorial/sequence_parallel/model/layers/bert_layer.py b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..4ede21516f65a8bc008a51dea505811fb0ceda48 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn +from colossalai.nn.layer.parallel_sequence import TransformerSelfAttentionRing +from colossalai.kernel.jit import bias_dropout_add_fused_train, bias_dropout_add_fused_inference +from colossalai.kernel.cuda_native import LayerNorm +from .mlp import TransformerMLP +from .dropout import get_bias_dropout_add + + +def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + +class BertLayer(nn.Module): + """A single transformer layer. + Transformer layer takes input with size [b, s, h] and returns an + output of the same size. + """ + + def __init__(self, + layer_number, + hidden_size, + num_attention_heads, + attention_dropout, + mlp_ratio, + hidden_dropout, + is_naive_fp16, + apply_residual_connection_post_layernorm=False, + fp32_residual_connection=False, + bias_dropout_fusion: bool = True, + convert_fp16_to_fp32_in_softmax: bool = False): + super().__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.fp32_residual_connection = fp32_residual_connection + + # Layernorm on the input data. + self.input_layernorm = LayerNorm(hidden_size) + + # Self attention. + self.self_attention = TransformerSelfAttentionRing( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=attention_dropout, + attention_mask_func=attention_mask_func, + layer_number=layer_number, + apply_query_key_layer_scaling=True, + convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, + fp16=is_naive_fp16 + ) + + self.hidden_dropout = hidden_dropout + self.bias_dropout_fusion = bias_dropout_fusion + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNorm(hidden_size) + + self.mlp = TransformerMLP(hidden_size=hidden_size, mlp_ratio=mlp_ratio) + + def forward(self, hidden_states, attention_mask): + # hidden_states: [batch_size, sub_seq_len, hidden_size] + # attention_mask: [batch_size, 1, sub_seq_len, seq_len] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Self attention. + attention_output, attention_bias = self.self_attention(layernorm_output, attention_mask) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # jit scripting for a nn.module (with dropout) is not + # trigerring the fusion kernel. For now, we use two + # different nn.functional routines to account for varying + # dropout semantics during training and inference phases. + if self.bias_dropout_fusion: + if self.training: + bias_dropout_add_func = bias_dropout_add_fused_train + else: + bias_dropout_add_func = bias_dropout_add_fused_inference + else: + bias_dropout_add_func = get_bias_dropout_add(self.training) + + # re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + layernorm_input = bias_dropout_add_func( + attention_output, + attention_bias.expand_as(residual), + residual, + self.hidden_dropout) + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output, mlp_bias = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + # re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + output = bias_dropout_add_func( + mlp_output, + mlp_bias.expand_as(residual), + residual, + self.hidden_dropout) + + return output diff --git a/examples/tutorial/sequence_parallel/model/layers/dropout.py b/examples/tutorial/sequence_parallel/model/layers/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..0e99105b8f7e9d1ce2f884eba12678d5b18f4468 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/dropout.py @@ -0,0 +1,13 @@ +import torch + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor + out = torch.nn.functional.dropout(x + bias, p=prob, training=training) + out = residual + out + return out + + +def get_bias_dropout_add(training): + def _bias_dropout_add(x, bias, residual, prob): + return bias_dropout_add(x, bias, residual, prob, training) + return _bias_dropout_add \ No newline at end of file diff --git a/examples/tutorial/sequence_parallel/model/layers/embedding.py b/examples/tutorial/sequence_parallel/model/layers/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..0700d960d845ff1ab1f0258d6e174b88ad3ac902 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/embedding.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + + +class VocabEmbedding(torch.nn.Module): + + def __init__(self, num_embeddings, embedding_dim): + super(VocabEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = None + self.max_norm = None + self.norm_type = 2. + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + + # Allocate weights and initialize. + self.weight = nn.Parameter(torch.empty( + self.num_embeddings, self.embedding_dim)) + init.xavier_uniform_(self.weight) + + def forward(self, hidden_state): + output = F.embedding(hidden_state, self.weight, + self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, + self.sparse) + return output + + def __repr__(self): + return f'VocabEmbedding(num_embeddings={self.num_embeddings}, ' \ + f'embedding_dim={self.embedding_dim})' + + +class Embedding(nn.Module): + """Language model embeddings. + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + init_method: weight initialization method + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, + hidden_size, + vocab_size, + max_sequence_length, + embedding_dropout_prob, + num_tokentypes): + super(Embedding, self).__init__() + + self.hidden_size = hidden_size + self.num_tokentypes = num_tokentypes + + self.word_embeddings = VocabEmbedding(vocab_size, self.hidden_size) + + # Position embedding (serial). + self.position_embeddings = torch.nn.Embedding( + max_sequence_length, self.hidden_size) + + # Token type embedding. + # Add this as an optional field that can be added through + # method call so we can load a pretrain model without + # token types and add them as needed. + if self.num_tokentypes > 0: + self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, + self.hidden_size) + else: + self.tokentype_embeddings = None + + # Embeddings dropout + self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) + + @property + def word_embedding_weight(self): + return self.word_embeddings.weight + + def forward(self, input_ids, position_ids, tokentype_ids=None): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = words_embeddings + position_embeddings + if tokentype_ids is not None and self.tokentype_embeddings is not None: + embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) + + # Dropout. + embeddings = self.embedding_dropout(embeddings) + + return embeddings diff --git a/examples/tutorial/sequence_parallel/model/layers/head.py b/examples/tutorial/sequence_parallel/model/layers/head.py new file mode 100644 index 0000000000000000000000000000000000000000..ea336b9d131e3405dc22d03f3e7b51add4e54466 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/head.py @@ -0,0 +1,78 @@ +import colossalai +import torch +import torch.nn as nn +import torch.nn.functional as F +from .pooler import Pooler +from .linear import Linear +from .embedding import VocabEmbedding +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.kernel import LayerNorm +from loss_func.cross_entropy import vocab_cross_entropy + + +class BertLMHead(nn.Module): + """Masked LM head for Bert + Arguments: + hidden_size: hidden size + init_method: init method for weight initialization + layernorm_epsilon: tolerance for layer norm divisions + """ + + def __init__(self, + vocab_size, + hidden_size, + ): + + super(BertLMHead, self).__init__() + self.bias = torch.nn.Parameter(torch.zeros(vocab_size)) + + self.dense = Linear(hidden_size, hidden_size) + self.layernorm = LayerNorm(hidden_size) + self.gelu = torch.nn.functional.gelu + + def forward(self, hidden_states, word_embeddings_weight, lm_labels): + hidden_states = self.dense(hidden_states) + hidden_states = self.gelu(hidden_states) + hidden_states = self.layernorm(hidden_states) + + output = F.linear(hidden_states, word_embeddings_weight, self.bias) + lm_loss = vocab_cross_entropy(output, lm_labels) + + return lm_loss + + +class BertBinaryHead(nn.Module): + + def __init__(self, hidden_size): + super().__init__() + self.pooler = Pooler(hidden_size) + self.dense = Linear(hidden_size, 2) + + def forward(self, hidden_states): + if gpc.get_local_rank(ParallelMode.SEQUENCE) == 0: + output = self.pooler(hidden_states) + output = self.dense(output) + else: + output = None + return output + + +class BertDualHead(nn.Module): + + def __init__(self, hidden_size, vocab_size, add_binary_head): + super().__init__() + self.lm_head = BertLMHead(vocab_size, hidden_size) + self.add_binary_head = add_binary_head + if add_binary_head: + self.binary_head = BertBinaryHead(hidden_size) + else: + self.binary_head = None + + def forward(self, hidden_states, word_embeddings_weight, lm_labels): + if self.add_binary_head: + binary_output = self.binary_head(hidden_states) + else: + binary_output = None + lm_loss = self.lm_head(hidden_states, word_embeddings_weight, lm_labels) + return lm_loss, binary_output diff --git a/examples/tutorial/sequence_parallel/model/layers/init_method.py b/examples/tutorial/sequence_parallel/model/layers/init_method.py new file mode 100644 index 0000000000000000000000000000000000000000..1b409dfe40541524891f70fc7c7d8297afa86999 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/init_method.py @@ -0,0 +1,12 @@ +import torch +import math + +def init_normal(tensor, sigma): + """Init method based on N(0, sigma).""" + torch.nn.init.normal_(tensor, mean=0.0, std=sigma) + + +def output_init_normal(tensor, sigma, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = sigma / math.sqrt(2.0 * num_layers) + torch.nn.init.normal_(tensor, mean=0.0, std=std) diff --git a/examples/tutorial/sequence_parallel/model/layers/linear.py b/examples/tutorial/sequence_parallel/model/layers/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae7d671e2bf2312da315d35629dcdc12ca075de --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/linear.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +from torch.nn import Parameter +import torch.nn.functional as F +import torch.nn.init as init + + +class Linear(nn.Module): + """Linear layer with column parallelism. + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimations where bias + can be fused with other elementwise operations. we skip + adding bias but instead return it. + """ + + def __init__(self, + input_size, + output_size, + bias=True, + skip_bias_add=False): + super(Linear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + + self.weight = Parameter(torch.empty(self.output_size, + self.input_size, + )) + init.normal_(self.weight) + if bias: + self.bias = Parameter(torch.empty(self.output_size)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_): + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output = F.linear(input_, self.weight, bias) + + if self.skip_bias_add: + return output, self.bias + else: + return output + + def __repr__(self): + return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \ + f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})' diff --git a/examples/tutorial/sequence_parallel/model/layers/mlp.py b/examples/tutorial/sequence_parallel/model/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..a255de813d135e5c79a281c86462340337c5c036 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/mlp.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .linear import Linear +from colossalai.kernel.jit import bias_gelu_impl + + +class TransformerMLP(nn.Module): + """MLP. + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. At the end, dropout is also + applied. + """ + + def __init__(self, hidden_size, mlp_ratio, fuse_gelu=True): + super(TransformerMLP, self).__init__() + + # Project to 4h. + self.dense_h_to_4h = Linear( + hidden_size, + int(hidden_size*mlp_ratio), + skip_bias_add=True) + + self.bias_gelu_fusion = fuse_gelu + self.activation_func = F.gelu + + # Project back to h. + self.dense_4h_to_h = Linear( + int(hidden_size*mlp_ratio), + hidden_size, + skip_bias_add=True) + + def forward(self, hidden_states): + # hidden states should be in the shape of [s, b, h] + # it will be projects into [s, b, 4h] + # and projected back to [s, b, h] + intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) + + if self.bias_gelu_fusion: + intermediate_parallel = \ + bias_gelu_impl(intermediate_parallel, bias_parallel) + else: + intermediate_parallel = \ + self.activation_func(intermediate_parallel + bias_parallel) + + # [s, b, h] + output, output_bias = self.dense_4h_to_h(intermediate_parallel) + return output, output_bias diff --git a/examples/tutorial/sequence_parallel/model/layers/pooler.py b/examples/tutorial/sequence_parallel/model/layers/pooler.py new file mode 100644 index 0000000000000000000000000000000000000000..282ed114790b32618c1a92924bce403167d1b89e --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/pooler.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn +from .linear import Linear + + +class Pooler(nn.Module): + """Pooler layer. + + Pool hidden states of a specific token (for example start of the + sequence) and add a linear transformation followed by a tanh. + + Arguments: + hidden_size: hidden size + init_method: weight initialization method for the linear layer. + bias is set to zero. + """ + + def __init__(self, hidden_size): + super(Pooler, self).__init__() + self.dense = Linear(hidden_size, hidden_size) + + def forward(self, hidden_states, sequence_index=0): + # hidden_states: [b, s, h] + # sequence_index: index of the token to pool. + pooled = hidden_states[:, sequence_index, :] + pooled = self.dense(pooled) + pooled = torch.tanh(pooled) + return pooled diff --git a/examples/tutorial/sequence_parallel/model/layers/preprocess.py b/examples/tutorial/sequence_parallel/model/layers/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..53a326ddacf14f7babaa18390a04b962bd84c809 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/preprocess.py @@ -0,0 +1,58 @@ +from colossalai.context.parallel_mode import ParallelMode +import torch +import torch.nn as nn +from colossalai.core import global_context as gpc + + +class PreProcessor(nn.Module): + + def __init__(self, sub_seq_length): + super().__init__() + self.sub_seq_length = sub_seq_length + + def bert_position_ids(self, token_ids): + # Create position ids + seq_length = token_ids.size(1) + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + position_ids = torch.arange(seq_length*local_rank, + seq_length * (local_rank+1), + dtype=torch.long, + device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids + + def bert_extended_attention_mask(self, attention_mask): + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + start_index = local_rank * self.sub_seq_length + end_index = (local_rank + 1) * self.sub_seq_length + + # We create a 3D attention mask from a 2D tensor mask. + # [b, 1, s] + attention_mask_b1s = attention_mask.unsqueeze(1) + # [b, s, 1] + attention_mask_bs1 = attention_mask.unsqueeze(2) + # [b, s/D, s] + attention_mask_bss = attention_mask_b1s * attention_mask_bs1 + + attention_mask_bss = attention_mask_bss[:, start_index:end_index, :] + + # [b, 1, s/D, s] + extended_attention_mask = attention_mask_bss.unsqueeze(1) + + # Convert attention mask to binary: + extended_attention_mask = (extended_attention_mask < 0.5) + + return extended_attention_mask + + def forward(self, input_ids=None, attention_mask=None): + if attention_mask is not None: + extended_attention_mask = self.bert_extended_attention_mask(attention_mask) + else: + extended_attention_mask = None + + if input_ids is not None: + position_ids = self.bert_position_ids(input_ids) + else: + position_ids = None + return position_ids, extended_attention_mask diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py new file mode 100644 index 0000000000000000000000000000000000000000..b92061000d10229cda90c123c78ba6909a692401 --- /dev/null +++ b/examples/tutorial/sequence_parallel/train.py @@ -0,0 +1,240 @@ +import argparse + +import torch +from data import build_train_valid_test_data_iterators +from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel +from data.tokenizer import get_padded_vocab_size, initialize_tokenizer +from loss_func.bert_loss import BertLoss +from lr_scheduler import AnnealingLR +from model.bert import BertForPretrain, build_pipeline_bert + +import colossalai +from colossalai.amp import AMP_TYPE +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.engine.schedule import PipelineSchedule +from colossalai.kernel import LayerNorm +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import FusedAdam +from colossalai.utils import MultiTimer, is_using_pp + + +def process_batch_data(batch_data): + tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = batch_data + if gpc.is_first_rank(ParallelMode.PIPELINE): + data = dict(input_ids=tokens, attention_masks=padding_mask, tokentype_ids=types, lm_labels=lm_labels) + else: + data = dict(attention_masks=padding_mask, tokentype_ids=types, lm_labels=lm_labels) + label = dict(loss_mask=loss_mask, sentence_order=sentence_order) + return data, label + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data") + return parser.parse_args() + + +def pipeline_data_process_func(stage_output, micro_batch_data): + tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data + if gpc.is_first_rank(ParallelMode.PIPELINE): + data = (tokens, padding_mask, types, lm_labels) + label = (loss_mask, sentence_order) + else: + data = (stage_output, padding_mask, types, lm_labels) + label = (loss_mask, sentence_order) + return data, label + + +def main(): + # initialize + args = parse_args() + colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl') + + logger = get_dist_logger() + + # build dataloader + if not args.synthetic: + initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase') + VOCAB_SIZE = get_padded_vocab_size() + trainloader, validloader, testloader = build_train_valid_test_data_iterators( + train_iters=gpc.config.TRAIN_ITERS, + global_batch_size=gpc.config.GLOBAL_BATCH_SIZE, + eval_interval=gpc.config.EVAL_INTERVAL, + eval_iters=gpc.config.EVAL_ITERS, + data_prefix=[gpc.config.DATA_PATH], + data_impl='mmap', + splits_string='949,50,1', + max_seq_length=gpc.config.SEQ_LENGTH, + masked_lm_prob=0.15, + short_seq_prob=0.1, + seed=1234, + skip_warmup=True, + binary_head=False, + ) + else: + from data.dummy_dataloader import DummyDataloader + + BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA) + VOCAB_SIZE = 30528 + trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS, + vocab_size=VOCAB_SIZE, + seq_length=gpc.config.SEQ_LENGTH) + validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS, + vocab_size=VOCAB_SIZE, + seq_length=gpc.config.SEQ_LENGTH) + + logger.info("Dataloaders are built", ranks=[0]) + + # build model + if hasattr(gpc.config, 'fp16') and gpc.config.fp16.get('mode') == AMP_TYPE.NAIVE: + is_naive_fp16 = True + else: + is_naive_fp16 = False + + use_pipeline = is_using_pp() + kwargs = dict(vocab_size=VOCAB_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + max_sequence_length=gpc.config.SEQ_LENGTH, + num_attention_heads=gpc.config.NUM_ATTENTION_HEADS, + convert_fp16_to_fp32_in_softmax=True, + is_naive_fp16=is_naive_fp16, + add_binary_head=gpc.config.ADD_BINARY_HEAD) + + if use_pipeline: + model = build_pipeline_bert(num_layers=gpc.config.DEPTH, num_chunks=1, **kwargs) + else: + model = BertForPretrain(num_layers=gpc.config.DEPTH, **kwargs) + + model = model.half() + model.reset_parameters() + logger.info(f"Model is built with softmax in fp32 = {is_naive_fp16}", ranks=[0]) + + total_numel = 0 + for p in model.parameters(): + total_numel += p.numel() + logger.info(f"This model has {total_numel} parameters") + + # build criterion + criterion = BertLoss() + logger.info("Criterion is built", ranks=[0]) + + # layernorm and bias has no weight decay + weight_decay_params = {'params': []} + no_weight_decay_params = {'params': [], 'weight_decay': 0.0} + for module_ in model.modules(): + if isinstance(module_, LayerNorm): + no_weight_decay_params['params'].extend([p for p in list(module_._parameters.values()) if p is not None]) + else: + weight_decay_params['params'].extend( + [p for n, p in list(module_._parameters.items()) if p is not None and n != 'bias']) + no_weight_decay_params['params'].extend( + [p for n, p in list(module_._parameters.items()) if p is not None and n == 'bias']) + + logger.info( + f"without weight decay param: {len(no_weight_decay_params['params'])}, with weight decay param: {len(weight_decay_params['params'])}" + ) + # optimizer + optimizer = FusedAdam((weight_decay_params, no_weight_decay_params), + lr=gpc.config.LR, + weight_decay=gpc.config.WEIGHT_DECAY) + logger.info("Optimizer is built", ranks=[0]) + + # lr scheduler + # follow Megatron-LM setting + warmup_steps = int(gpc.config.DECAY_ITERS * gpc.config.WARMUP_FRACTION) + lr_scheduler = AnnealingLR(optimizer=optimizer, + max_lr=gpc.config.LR, + min_lr=gpc.config.MIN_LR, + warmup_steps=warmup_steps, + decay_steps=gpc.config.DECAY_ITERS, + decay_style='linear') + logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps") + + # # init + engine, *dummy = colossalai.initialize(model, optimizer, criterion, verbose=True) + + # build timer + timer = MultiTimer() + skip_iters = 0 + + # build loss tracker + accumulated_train_loss = torch.zeros(1, dtype=torch.float32).cuda() + accumulated_eval_loss = torch.zeros(1, dtype=torch.float32).cuda() + + # build data iters for pipeline parallel + if use_pipeline: + train_data_iter = SequenceParallelDataIterator(trainloader) + valid_data_iter = SequenceParallelDataIterator(validloader) + engine.schedule.data_process_func = pipeline_data_process_func + + logger.info("start training") + + for step in range(1, gpc.config.TRAIN_ITERS + 1): + timer.start('train-iterations') + engine.train() + if use_pipeline: + engine.zero_grad() + _, _, train_loss = engine.execute_schedule(train_data_iter, return_output_label=False) + engine.step() + else: + tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel( + trainloader) + engine.zero_grad() + lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels) + train_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order) + engine.backward(train_loss) + engine.step() + timer.stop('train-iterations', keep_in_history=True) + + if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE): + accumulated_train_loss += train_loss + + lr_scheduler.step() + + if step % gpc.config.EVAL_INTERVAL == 0: + engine.eval() + + for j in range(gpc.config.EVAL_ITERS): + with torch.no_grad(): + if use_pipeline: + _, _, eval_loss = engine.execute_schedule(valid_data_iter, + forward_only=True, + return_output_label=False) + else: + tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel( + validloader) + lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels) + eval_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order) + + if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE): + accumulated_eval_loss += eval_loss + + if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE): + accumulated_eval_loss /= gpc.config.EVAL_ITERS + accumulated_train_loss /= gpc.config.EVAL_INTERVAL + + timer_string = [] + for n, t in timer: + timer_string.append(f"{n}: {t.get_history_mean()*1000:.5f}") + timer_string = ' | '.join(timer_string) + lr = list(engine.optimizer.param_groups)[0]['lr'] + loss_scale = engine.optimizer.optim.loss_scale.item() + + if gpc.is_initialized(ParallelMode.PIPELINE): + ranks = [gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1]] + else: + ranks = [0] + logger.info(f'Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} ' + + f'| Eval Loss: {accumulated_eval_loss.item():.5g} ' + f'| Loss Scale: {loss_scale}' + + f"| Learning rate: {lr} | " + timer_string, + ranks=ranks) + + for n, t in timer: + t.reset() + accumulated_eval_loss.zero_() + accumulated_train_loss.zero_() + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/stable_diffusion/LICENSE b/examples/tutorial/stable_diffusion/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0e609df0d8cd3b5d11a1ea962a56b604b70846a5 --- /dev/null +++ b/examples/tutorial/stable_diffusion/LICENSE @@ -0,0 +1,82 @@ +Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors + +CreativeML Open RAIL-M +dated August 22, 2022 + +Section I: PREAMBLE + +Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation. + +Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations. + +In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation. + +Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI. + +This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model. + +NOW THEREFORE, You and Licensor agree as follows: + +1. Definitions + +- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document. +- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. +- "Output" means the results of operating a Model as embodied in informational content resulting therefrom. +- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material. +- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model. +- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any. +- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access. +- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model. +- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator. +- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You. +- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." +- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model. + +Section II: INTELLECTUAL PROPERTY RIGHTS + +Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model. +3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed. + +Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION + +4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions: +Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material. +You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License; +You must cause any modified files to carry prominent notices stating that You changed the files; +You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model. +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License. +5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5). +6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License. + +Section IV: OTHER PROVISIONS + +7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model. +8. Trademarks and related. Nothing in this License permits You to make use of Licensorsโ€™ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors. +9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License. +10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. +11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. +12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. + +END OF TERMS AND CONDITIONS + + + + +Attachment A + +Use Restrictions + +You agree not to use the Model or Derivatives of the Model: +- In any way that violates any applicable national, federal, state, local or international law or regulation; +- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; +- To generate or disseminate verifiably false information and/or content with the purpose of harming others; +- To generate or disseminate personal identifiable information that can be used to harm an individual; +- To defame, disparage or otherwise harass others; +- For fully automated decision making that adversely impacts an individualโ€™s legal rights or otherwise creates or modifies a binding, enforceable obligation; +- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics; +- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; +- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories; +- To provide medical advice and medical results interpretation; +- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use). diff --git a/examples/tutorial/stable_diffusion/README.md b/examples/tutorial/stable_diffusion/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a0ece4485d27ff32853c1eabf5195b3edaf295fc --- /dev/null +++ b/examples/tutorial/stable_diffusion/README.md @@ -0,0 +1,149 @@ +# Stable Diffusion with Colossal-AI +*[Colosssal-AI](https://github.com/hpcaitech/ColossalAI) provides a faster and lower cost solution for pretraining and +fine-tuning for AIGC (AI-Generated Content) applications such as the model [stable-diffusion](https://github.com/CompVis/stable-diffusion) from [Stability AI](https://stability.ai/).* + +We take advantage of [Colosssal-AI](https://github.com/hpcaitech/ColossalAI) to exploit multiple optimization strategies +, e.g. data parallelism, tensor parallelism, mixed precision & ZeRO, to scale the training to multiple GPUs. + +## ๐Ÿš€Quick Start +1. Create a new environment for diffusion +```bash +conda env create -f environment.yaml +conda activate ldm +``` +2. Install Colossal-AI from our official page +```bash +pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org +``` +3. Install PyTorch Lightning compatible commit +```bash +git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa +pip install -r requirements.txt && pip install . +cd .. +``` + +4. Comment out the `from_pretrained` field in the `train_colossalai_cifar10.yaml`. +5. Run training with CIFAR10. +```bash +python main.py -logdir /tmp -t true -postfix test -b configs/train_colossalai_cifar10.yaml +``` + +## Stable Diffusion +[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) is a latent text-to-image diffusion +model. +Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database. +Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487), +this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. + +

+ +

+ +[Stable Diffusion with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion) provides **6.5x faster training and pretraining cost saving, the hardware cost of fine-tuning can be almost 7X cheaper** (from RTX3090/4090 24GB to RTX3050/2070 8GB). + +

+ +

+ +## Requirements +A suitable [conda](https://conda.io/) environment named `ldm` can be created +and activated with: + +``` +conda env create -f environment.yaml +conda activate ldm +``` + +You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running + +``` +conda install pytorch torchvision -c pytorch +pip install transformers==4.19.2 diffusers invisible-watermark +pip install -e . +``` + +### Install [Colossal-AI v0.1.10](https://colossalai.org/download/) From Our Official Website +``` +pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org +``` + +### Install [Lightning](https://github.com/Lightning-AI/lightning) +We use the Sep. 2022 version with commit id as `b04a7aa`. +``` +git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa +pip install -r requirements.txt && pip install . +``` + +> The specified version is due to the interface incompatibility caused by the latest update of [Lightning](https://github.com/Lightning-AI/lightning), which will be fixed in the near future. + +## Dataset +The dataSet is from [LAION-5B](https://laion.ai/blog/laion-5b/), the subset of [LAION](https://laion.ai/), +you should the change the `data.file_path` in the `config/train_colossalai.yaml` + +## Training + +We provide the script `train.sh` to run the training task , and two Stategy in `configs`:`train_colossalai.yaml` + +For example, you can run the training from colossalai by +``` +python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml +``` + +- you can change the `--logdir` the save the log information and the last checkpoint + +### Training config +You can change the trainging config in the yaml file + +- accelerator: acceleratortype, default 'gpu' +- devices: device number used for training, default 4 +- max_epochs: max training epochs +- precision: usefp16 for training or not, default 16, you must use fp16 if you want to apply colossalai + +## Example + +### Training on cifar10 + +We provide the finetuning example on CIFAR10 dataset + +You can run by config `train_colossalai_cifar10.yaml` +``` +python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai_cifar10.yaml +``` + + + +## Comments + +- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion) +, [lucidrains](https://github.com/lucidrains/denoising-diffusion-pytorch), +[Stable Diffusion](https://github.com/CompVis/stable-diffusion), [Lightning](https://github.com/Lightning-AI/lightning) and [Hugging Face](https://huggingface.co/CompVis/stable-diffusion). +Thanks for open-sourcing! + +- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories). + +- The implementation of [flash attention](https://github.com/HazyResearch/flash-attention) is from [HazyResearch](https://github.com/HazyResearch). + +## BibTeX + +``` +@article{bian2021colossal, + title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, + author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, + journal={arXiv preprint arXiv:2110.14883}, + year={2021} +} +@misc{rombach2021highresolution, + title={High-Resolution Image Synthesis with Latent Diffusion Models}, + author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Bjรถrn Ommer}, + year={2021}, + eprint={2112.10752}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +@article{dao2022flashattention, + title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness}, + author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, + journal={arXiv preprint arXiv:2205.14135}, + year={2022} +} +``` diff --git a/examples/tutorial/stable_diffusion/configs/train_colossalai.yaml b/examples/tutorial/stable_diffusion/configs/train_colossalai.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c457787dd881f7aa60ff785abe6e810d62f45b68 --- /dev/null +++ b/examples/tutorial/stable_diffusion/configs/train_colossalai.yaml @@ -0,0 +1,116 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: caption + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin' + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: False + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin' + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + params: + use_fp16: True + +data: + target: main.DataModuleFromConfig + params: + batch_size: 64 + wrap: False + train: + target: ldm.data.base.Txt2ImgIterableBaseDataset + params: + file_path: "/data/scratch/diffuser/laion_part0/" + world_size: 1 + rank: 0 + +lightning: + trainer: + accelerator: 'gpu' + devices: 4 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: pytorch_lightning.strategies.ColossalAIStrategy + params: + use_chunk: False + enable_distributed_storage: True, + placement_policy: cuda + force_outputs_fp32: False + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + profiler: pytorch + + logger_config: + wandb: + target: pytorch_lightning.loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/configs/train_colossalai_cifar10.yaml b/examples/tutorial/stable_diffusion/configs/train_colossalai_cifar10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..63b9d1c0179c8842b95cda5f0ddbe54f04b7c1f9 --- /dev/null +++ b/examples/tutorial/stable_diffusion/configs/train_colossalai_cifar10.yaml @@ -0,0 +1,123 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin' + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: False + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin' + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + params: + use_fp16: True + +data: + target: main.DataModuleFromConfig + params: + batch_size: 4 + num_workers: 4 + train: + target: ldm.data.cifar10.hf_dataset + params: + name: cifar10 + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + - target: torchvision.transforms.RandomHorizontalFlip + +lightning: + trainer: + accelerator: 'gpu' + devices: 2 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: pytorch_lightning.strategies.ColossalAIStrategy + params: + use_chunk: False + enable_distributed_storage: True, + placement_policy: cuda + force_outputs_fp32: False + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + profiler: pytorch + + logger_config: + wandb: + target: pytorch_lightning.loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/configs/train_ddp.yaml b/examples/tutorial/stable_diffusion/configs/train_ddp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..90d41258fada7eaa6f7fe478a1be7dda08b286ca --- /dev/null +++ b/examples/tutorial/stable_diffusion/configs/train_ddp.yaml @@ -0,0 +1,113 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: caption + image_size: 32 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 100 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin' + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: False + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin' + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + params: + use_fp16: True + +data: + target: main.DataModuleFromConfig + params: + batch_size: 64 + wrap: False + train: + target: ldm.data.base.Txt2ImgIterableBaseDataset + params: + file_path: "/data/scratch/diffuser/laion_part0/" + world_size: 1 + rank: 0 + +lightning: + trainer: + accelerator: 'gpu' + devices: 4 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: pytorch_lightning.strategies.DDPStrategy + params: + find_unused_parameters: False + log_every_n_steps: 2 +# max_steps: 6o + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + target: pytorch_lightning.loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/configs/train_pokemon.yaml b/examples/tutorial/stable_diffusion/configs/train_pokemon.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8b5d2adfaf17efb73c01714a212ce9a7bfb37e94 --- /dev/null +++ b/examples/tutorial/stable_diffusion/configs/train_pokemon.yaml @@ -0,0 +1,121 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: caption + image_size: 32 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + check_nan_inf: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin' + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: False + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin' + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + params: + use_fp16: True + +data: + target: main.DataModuleFromConfig + params: + batch_size: 32 + wrap: False + train: + target: ldm.data.pokemon.PokemonDataset + # params: + # file_path: "/data/scratch/diffuser/laion_part0/" + # world_size: 1 + # rank: 0 + +lightning: + trainer: + accelerator: 'gpu' + devices: 4 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: pytorch_lightning.strategies.ColossalAIStrategy + params: + use_chunk: False + enable_distributed_storage: True, + placement_policy: cuda + force_outputs_fp32: False + initial_scale: 65536 + min_scale: 1 + max_scale: 65536 + # max_scale: 4294967296 + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + profiler: pytorch + + logger_config: + wandb: + target: pytorch_lightning.loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/environment.yaml b/examples/tutorial/stable_diffusion/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..59baa3c76162f63aa6a2a793e62f32879d459936 --- /dev/null +++ b/examples/tutorial/stable_diffusion/environment.yaml @@ -0,0 +1,33 @@ +name: ldm +channels: + - pytorch + - defaults +dependencies: + - python=3.9.12 + - pip=20.3 + - cudatoolkit=11.3 + - pytorch=1.11.0 + - torchvision=0.12.0 + - numpy=1.19.2 + - pip: + - albumentations==0.4.3 + - datasets + - diffusers + - opencv-python==4.6.0.66 + - pudb==2019.2 + - invisible-watermark + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - pytorch-lightning==1.8.0 + - omegaconf==2.1.1 + - test-tube>=0.7.5 + - streamlit>=0.73.1 + - einops==0.3.0 + - torch-fidelity==0.3.0 + - transformers==4.19.2 + - torchmetrics==0.7.0 + - kornia==0.6 + - prefetch_generator + - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers + - -e git+https://github.com/openai/CLIP.git@main#egg=clip + - -e . diff --git a/examples/tutorial/stable_diffusion/ldm/data/__init__.py b/examples/tutorial/stable_diffusion/ldm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/tutorial/stable_diffusion/ldm/data/base.py b/examples/tutorial/stable_diffusion/ldm/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4f3cd35714a02d087e0a19ffd5f91ff514689ab9 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/data/base.py @@ -0,0 +1,75 @@ +import math +from abc import abstractmethod + +import torch +from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset +import os +import numpy as np +import cv2 + +class Txt2ImgIterableBaseDataset(IterableDataset): + ''' + Define an interface to make the IterableDatasets for text2img data chainable + ''' + def __init__(self, file_path: str, rank, world_size): + super().__init__() + self.file_path = file_path + self.folder_list = [] + self.file_list = [] + self.txt_list = [] + self.info = self._get_file_info(file_path) + self.start = self.info['start'] + self.end = self.info['end'] + self.rank = rank + + self.world_size = world_size + # self.per_worker = int(math.floor((self.end - self.start) / float(self.world_size))) + # self.iter_start = self.start + self.rank * self.per_worker + # self.iter_end = min(self.iter_start + self.per_worker, self.end) + # self.num_records = self.iter_end - self.iter_start + # self.valid_ids = [i for i in range(self.iter_end)] + self.num_records = self.end - self.start + self.valid_ids = [i for i in range(self.end)] + + print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + + def __len__(self): + # return self.iter_end - self.iter_start + return self.end - self.start + + def __iter__(self): + sample_iterator = self._sample_generator(self.start, self.end) + # sample_iterator = self._sample_generator(self.iter_start, self.iter_end) + return sample_iterator + + def _sample_generator(self, start, end): + for idx in range(start, end): + file_name = self.file_list[idx] + txt_name = self.txt_list[idx] + f_ = open(txt_name, 'r') + txt_ = f_.read() + f_.close() + image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = torch.from_numpy(image) / 255 + yield {"caption": txt_, "image":image} + + + def _get_file_info(self, file_path): + info = \ + { + "start": 1, + "end": 0, + } + self.folder_list = [file_path + i for i in os.listdir(file_path) if '.' not in i] + for folder in self.folder_list: + files = [folder + '/' + i for i in os.listdir(folder) if 'jpg' in i] + txts = [k.replace('jpg', 'txt') for k in files] + self.file_list.extend(files) + self.txt_list.extend(txts) + info['end'] = len(self.file_list) + # with open(file_path, 'r') as fin: + # for _ in enumerate(fin): + # info['end'] += 1 + # self.txt_list = [k.replace('jpg', 'txt') for k in self.file_list] + return info \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/data/cifar10.py b/examples/tutorial/stable_diffusion/ldm/data/cifar10.py new file mode 100644 index 0000000000000000000000000000000000000000..53cd61263b472d37c6b2b896cfd8ba2f89477b9a --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/data/cifar10.py @@ -0,0 +1,184 @@ +from typing import Dict +import numpy as np +from omegaconf import DictConfig, ListConfig +import torch +from torch.utils.data import Dataset +from pathlib import Path +import json +from PIL import Image +from torchvision import transforms +from einops import rearrange +from ldm.util import instantiate_from_config +from datasets import load_dataset + +def make_multi_folder_data(paths, caption_files=None, **kwargs): + """Make a concat dataset from multiple folders + Don't suport captions yet + If paths is a list, that's ok, if it's a Dict interpret it as: + k=folder v=n_times to repeat that + """ + list_of_paths = [] + if isinstance(paths, (Dict, DictConfig)): + assert caption_files is None, \ + "Caption files not yet supported for repeats" + for folder_path, repeats in paths.items(): + list_of_paths.extend([folder_path]*repeats) + paths = list_of_paths + + if caption_files is not None: + datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)] + else: + datasets = [FolderData(p, **kwargs) for p in paths] + return torch.utils.data.ConcatDataset(datasets) + +class FolderData(Dataset): + def __init__(self, + root_dir, + caption_file=None, + image_transforms=[], + ext="jpg", + default_caption="", + postprocess=None, + return_paths=False, + ) -> None: + """Create a dataset from a folder of images. + If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = Path(root_dir) + self.default_caption = default_caption + self.return_paths = return_paths + if isinstance(postprocess, DictConfig): + postprocess = instantiate_from_config(postprocess) + self.postprocess = postprocess + if caption_file is not None: + with open(caption_file, "rt") as f: + ext = Path(caption_file).suffix.lower() + if ext == ".json": + captions = json.load(f) + elif ext == ".jsonl": + lines = f.readlines() + lines = [json.loads(x) for x in lines] + captions = {x["file_name"]: x["text"].strip("\n") for x in lines} + else: + raise ValueError(f"Unrecognised format: {ext}") + self.captions = captions + else: + self.captions = None + + if not isinstance(ext, (tuple, list, ListConfig)): + ext = [ext] + + # Only used if there is no caption file + self.paths = [] + for e in ext: + self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) + if isinstance(image_transforms, ListConfig): + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms = transforms.Compose(image_transforms) + self.tform = image_transforms + + + def __len__(self): + if self.captions is not None: + return len(self.captions.keys()) + else: + return len(self.paths) + + def __getitem__(self, index): + data = {} + if self.captions is not None: + chosen = list(self.captions.keys())[index] + caption = self.captions.get(chosen, None) + if caption is None: + caption = self.default_caption + filename = self.root_dir/chosen + else: + filename = self.paths[index] + + if self.return_paths: + data["path"] = str(filename) + + im = Image.open(filename) + im = self.process_im(im) + data["image"] = im + + if self.captions is not None: + data["txt"] = caption + else: + data["txt"] = self.default_caption + + if self.postprocess is not None: + data = self.postprocess(data) + + return data + + def process_im(self, im): + im = im.convert("RGB") + return self.tform(im) + +def hf_dataset( + name, + image_transforms=[], + image_column="img", + label_column="label", + text_column="txt", + split='train', + image_key='image', + caption_key='txt', + ): + """Make huggingface dataset with appropriate list of transforms applied + """ + ds = load_dataset(name, split=split) + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + tform = transforms.Compose(image_transforms) + + assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" + assert label_column in ds.column_names, f"Didn't find column {label_column} in {ds.column_names}" + + def pre_process(examples): + processed = {} + processed[image_key] = [tform(im) for im in examples[image_column]] + + label_to_text_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"} + + processed[caption_key] = [label_to_text_dict[label] for label in examples[label_column]] + + return processed + + ds.set_transform(pre_process) + return ds + +class TextOnly(Dataset): + def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1): + """Returns only captions with dummy images""" + self.output_size = output_size + self.image_key = image_key + self.caption_key = caption_key + if isinstance(captions, Path): + self.captions = self._load_caption_file(captions) + else: + self.captions = captions + + if n_gpus > 1: + # hack to make sure that all the captions appear on each gpu + repeated = [n_gpus*[x] for x in self.captions] + self.captions = [] + [self.captions.extend(x) for x in repeated] + + def __len__(self): + return len(self.captions) + + def __getitem__(self, index): + dummy_im = torch.zeros(3, self.output_size, self.output_size) + dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c') + return {self.image_key: dummy_im, self.caption_key: self.captions[index]} + + def _load_caption_file(self, filename): + with open(filename, 'rt') as f: + captions = f.readlines() + return [x.strip('\n') for x in captions] \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/data/imagenet.py b/examples/tutorial/stable_diffusion/ldm/data/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..1c473f9c6965b22315dbb289eff8247c71bdc790 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/data/imagenet.py @@ -0,0 +1,394 @@ +import os, yaml, pickle, shutil, tarfile, glob +import cv2 +import albumentations +import PIL +import numpy as np +import torchvision.transforms.functional as TF +from omegaconf import OmegaConf +from functools import partial +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset, Subset + +import taming.data.utils as tdu +from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve +from taming.data.imagenet import ImagePaths + +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light + + +def synset2idx(path_to_yaml="data/index_synset.yaml"): + with open(path_to_yaml) as f: + di2s = yaml.load(f) + return dict((v,k) for k,v in di2s.items()) + + +class ImageNetBase(Dataset): + def __init__(self, config=None): + self.config = config or OmegaConf.create() + if not type(self.config)==dict: + self.config = OmegaConf.to_container(self.config) + self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) + self.process_images = True # if False we skip loading & processing images and self.data contains filepaths + self._prepare() + self._prepare_synset_to_human() + self._prepare_idx_to_synset() + self._prepare_human_to_integer_label() + self._load() + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _prepare(self): + raise NotImplementedError() + + def _filter_relpaths(self, relpaths): + ignore = set([ + "n06596364_9591.JPEG", + ]) + relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _prepare_human_to_integer_label(self): + URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" + self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") + if (not os.path.exists(self.human2integer)): + download(URL, self.human2integer) + with open(self.human2integer, "r") as f: + lines = f.read().splitlines() + assert len(lines) == 1000 + self.human2integer_dict = dict() + for line in lines: + value, key = line.split(":") + self.human2integer_dict[key] = int(value) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + if not self.keep_orig_class_label: + self.class_labels = [class_dict[s] for s in self.synsets] + else: + self.class_labels = [self.synset2idx[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + + if self.process_images: + self.size = retrieve(self.config, "size", default=256) + self.data = ImagePaths(self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) + else: + self.data = self.abspaths + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.process_images = process_images + self.data_root = data_root + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", + default=True) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.data_root = data_root + self.process_images = process_images + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", + default=False) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + + +class ImageNetSR(Dataset): + def __init__(self, size=None, + degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., + random_crop=True): + """ + Imagenet Superresolution Dataloader + Performs following ops in order: + 1. crops a crop of size s from image either as random or center crop + 2. resizes crop to size with cv2.area_interpolation + 3. degrades resized crop with degradation_fn + + :param size: resizing to size after cropping + :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light + :param downscale_f: Low Resolution Downsample factor + :param min_crop_f: determines crop size s, + where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) + :param max_crop_f: "" + :param data_root: + :param random_crop: + """ + self.base = self.get_base() + assert size + assert (size / downscale_f).is_integer() + self.size = size + self.LR_size = int(size / downscale_f) + self.min_crop_f = min_crop_f + self.max_crop_f = max_crop_f + assert(max_crop_f <= 1.) + self.center_crop = not random_crop + + self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) + + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + + if degradation == "bsrgan": + self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) + + elif degradation == "bsrgan_light": + self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) + + else: + interpolation_fn = { + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, + }[degradation] + + self.pil_interpolation = degradation.startswith("pil_") + + if self.pil_interpolation: + self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) + + else: + self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, + interpolation=interpolation_fn) + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = Image.open(example["file_path_"]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + image = np.array(image).astype(np.uint8) + + min_side_len = min(image.shape[:2]) + crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) + crop_side_len = int(crop_side_len) + + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) + + else: + self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) + + image = self.cropper(image=image)["image"] + image = self.image_rescaler(image=image)["image"] + + if self.pil_interpolation: + image_pil = PIL.Image.fromarray(image) + LR_image = self.degradation_process(image_pil) + LR_image = np.array(LR_image).astype(np.uint8) + + else: + LR_image = self.degradation_process(image=image)["image"] + + example["image"] = (image/127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + + return example + + +class ImageNetSRTrain(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_train_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetTrain(process_images=False,) + return Subset(dset, indices) + + +class ImageNetSRValidation(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_val_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetValidation(process_images=False,) + return Subset(dset, indices) diff --git a/examples/tutorial/stable_diffusion/ldm/data/lsun.py b/examples/tutorial/stable_diffusion/ldm/data/lsun.py new file mode 100644 index 0000000000000000000000000000000000000000..6256e45715ff0b57c53f985594d27cbbbff0e68e --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/data/lsun.py @@ -0,0 +1,92 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +class LSUNBase(Dataset): + def __init__(self, + txt_file, + data_root, + size=None, + interpolation="bicubic", + flip_p=0.5 + ): + self.data_paths = txt_file + self.data_root = data_root + with open(self.data_paths, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) + for l in self.image_paths], + } + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example + + +class LSUNChurchesTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) + + +class LSUNChurchesValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", + flip_p=flip_p, **kwargs) + + +class LSUNBedroomsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) + + +class LSUNBedroomsValidation(LSUNBase): + def __init__(self, flip_p=0.0, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", + flip_p=flip_p, **kwargs) + + +class LSUNCatsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) + + +class LSUNCatsValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", + flip_p=flip_p, **kwargs) diff --git a/examples/tutorial/stable_diffusion/ldm/lr_scheduler.py b/examples/tutorial/stable_diffusion/ldm/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..be39da9ca6dacc22bf3df9c7389bbb403a4a3ade --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/lr_scheduler.py @@ -0,0 +1,98 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/examples/tutorial/stable_diffusion/ldm/models/autoencoder.py b/examples/tutorial/stable_diffusion/ldm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..873d8b69bd22c182dfedf528b369f063cb4c3152 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/models/autoencoder.py @@ -0,0 +1,544 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + from_pretrained: str=None + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + from diffusers.modeling_utils import load_state_dict + if from_pretrained is not None: + state_dict = load_state_dict(from_pretrained) + self._load_pretrained_model(state_dict) + + def _state_key_mapping(self, state_dict: dict): + import re + res_dict = {} + key_list = state_dict.keys() + key_str = " ".join(key_list) + up_block_pattern = re.compile('upsamplers') + p1 = re.compile('mid.block_[0-9]') + p2 = re.compile('decoder.up.[0-9]') + up_blocks_count = int(len(re.findall(up_block_pattern, key_str)) / 2 + 1) + for key_, val_ in state_dict.items(): + key_ = key_.replace("up_blocks", "up").replace("down_blocks", "down").replace('resnets', 'block')\ + .replace('mid_block', 'mid').replace("mid.block.", "mid.block_")\ + .replace('mid.attentions.0.key', 'mid.attn_1.k')\ + .replace('mid.attentions.0.query', 'mid.attn_1.q') \ + .replace('mid.attentions.0.value', 'mid.attn_1.v') \ + .replace('mid.attentions.0.group_norm', 'mid.attn_1.norm') \ + .replace('mid.attentions.0.proj_attn', 'mid.attn_1.proj_out')\ + .replace('upsamplers.0', 'upsample')\ + .replace('downsamplers.0', 'downsample')\ + .replace('conv_shortcut', 'nin_shortcut')\ + .replace('conv_norm_out', 'norm_out') + + mid_list = re.findall(p1, key_) + if len(mid_list) != 0: + mid_str = mid_list[0] + mid_id = int(mid_str[-1]) + 1 + key_ = key_.replace(mid_str, mid_str[:-1] + str(mid_id)) + + up_list = re.findall(p2, key_) + if len(up_list) != 0: + up_str = up_list[0] + up_id = up_blocks_count - 1 -int(up_str[-1]) + key_ = key_.replace(up_str, up_str[:-1] + str(up_id)) + res_dict[key_] = val_ + return res_dict + + def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False): + state_dict = self._state_key_mapping(state_dict) + model_state_dict = self.state_dict() + loaded_keys = [k for k in state_dict.keys()] + expected_keys = list(model_state_dict.keys()) + original_loaded_keys = loaded_keys + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = self._load_state_dict_into_model(state_dict) + return missing_keys, unexpected_keys, mismatched_keys, error_msgs + + def _load_state_dict_into_model(self, state_dict): + # Convert old format to new format if needed from a PyTorch state_dict + # copy state_dict so _load_from_state_dict can modify it + state_dict = state_dict.copy() + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: torch.nn.Module, prefix=""): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(self) + + return error_msgs + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/__init__.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/classifier.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..67e98b9d8ffb96a150b517497ace0a242d7163ef --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddim.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..91335d6372df6d84b3f13297d980fe7fc5b635aa --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddim.py @@ -0,0 +1,240 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ + extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + return x_dec \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddpm.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..9633ec3d843a7df072960d81cc3483155cad4114 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1554 @@ +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid + +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from pytorch_lightning.utilities import rank_zero_info + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d +from ldm.modules.x_transformer import * +from ldm.modules.encoders.modules import * + +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import * +from ldm.models.diffusion.ddim import * +from ldm.modules.diffusionmodules.openaimodel import * +from ldm.modules.diffusionmodules.model import * + + +from ldm.modules.diffusionmodules.model import Model, Encoder, Decoder + +from ldm.util import instantiate_from_config + +from einops import rearrange, repeat + + + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + use_fp16 = True, + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + rank_zero_info(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.unet_config = unet_config + self.conditioning_key = conditioning_key + # self.model = DiffusionWrapper(unet_config, conditioning_key) + # count_params(self.model, verbose=True) + self.use_ema = use_ema + # if self.use_ema: + # self.model_ema = LitEma(self.model) + # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + self.ckpt_path = ckpt_path + self.ignore_keys = ignore_keys + self.load_only_unet = load_only_unet + self.given_betas = given_betas + self.beta_schedule = beta_schedule + self.timesteps = timesteps + self.linear_start = linear_start + self.linear_end = linear_end + self.cosine_s = cosine_s + # if ckpt_path is not None: + # self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + # + # self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + # linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar_init = logvar_init + # self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + # if self.learn_logvar: + # self.logvar = nn.Parameter(self.logvar, requires_grad=True) + # self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + self.use_fp16 = use_fp16 + if use_fp16: + self.unet_config["params"].update({"use_fp16": True}) + rank_zero_info("Using FP16 for UNet = {}".format(self.unet_config["params"]["use_fp16"])) + else: + self.unet_config["params"].update({"use_fp16": False}) + rank_zero_info("Using FP16 for UNet = {}".format(self.unet_config["params"]["use_fp16"])) + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + + if pred.isnan().any(): + print("Warning: Prediction has nan values") + lr = self.optimizers().param_groups[0]['lr'] + # self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + print(f"lr: {lr}") + if pred.isinf().any(): + print("Warning: Prediction has inf values") + + if self.use_fp16: + target = target.half() + + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + if loss.isnan().any(): + print("Warning: loss has nan values") + print("loss: ", loss[0][0][0]) + raise ValueError("loss has nan values") + if loss.isinf().any(): + print("Warning: loss has inf values") + print("loss: ", loss) + raise ValueError("loss has inf values") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + # print("+" * 30) + # print(batch['jpg'].shape) + # print(len(batch['txt'])) + # print(k) + # print("=" * 30) + if not isinstance(batch, torch.Tensor): + x = batch[k] + else: + x = batch + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + + if self.use_fp16: + x = x.to(memory_format=torch.contiguous_format).float().half() + else: + x = x.to(memory_format=torch.contiguous_format).float() + + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + use_fp16=True, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, use_fp16=use_fp16, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.first_stage_config = first_stage_config + self.cond_stage_config = cond_stage_config + if self.use_fp16: + self.cond_stage_config["params"].update({"use_fp16": True}) + rank_zero_info("Using fp16 for conditioning stage = {}".format(self.cond_stage_config["params"]["use_fp16"])) + else: + self.cond_stage_config["params"].update({"use_fp16": False}) + rank_zero_info("Using fp16 for conditioning stage = {}".format(self.cond_stage_config["params"]["use_fp16"])) + # self.instantiate_first_stage(first_stage_config) + # self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + + + def configure_sharded_model(self) -> None: + self.model = DiffusionWrapper(self.unet_config, self.conditioning_key) + count_params(self.model, verbose=True) + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + + self.register_schedule(given_betas=self.given_betas, beta_schedule=self.beta_schedule, timesteps=self.timesteps, + linear_start=self.linear_start, linear_end=self.linear_end, cosine_s=self.cosine_s) + + self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + # self.logvar = nn.Parameter(self.logvar, requires_grad=True) + if self.ckpt_path is not None: + self.init_from_ckpt(self.ckpt_path, self.ignore_keys) + self.restarted_from_ckpt = True + + # TODO() + # for p in self.model.modules(): + # if not p.parameters().data.is_contiguous: + # p.data = p.data.contiguous() + + self.instantiate_first_stage(self.first_stage_config) + self.instantiate_cond_stage(self.cond_stage_config) + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + + + @rank_zero_only + @torch.no_grad() + # def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox', 'txt']: + xc = batch[cond_key] + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + # import pudb; pudb.set_trace() + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + print(patch_limits_tknzd[0].shape) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + print(cut_cond.shape) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + print(adapted_cond.shape) + adapted_cond = self.get_learned_conditioning(adapted_cond) + print(adapted_cond.shape) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + print(adapted_cond.shape) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, + shape,cond,verbose=False,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, **kwargs): + + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + from colossalai.nn.optimizer import HybridAdam + opt = HybridAdam(params, lr=lr) + # opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + rank_zero_info("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/plms.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/plms.py new file mode 100644 index 0000000000000000000000000000000000000000..78eeb1003aa45d27bdbfc6b4a1d7ccbff57cd2e3 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/models/diffusion/plms.py @@ -0,0 +1,236 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/examples/tutorial/stable_diffusion/ldm/modules/attention.py b/examples/tutorial/stable_diffusion/ldm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3401ceafddb4f2338b7bd265fc1302c25a05bd29 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/attention.py @@ -0,0 +1,314 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from torch.utils import checkpoint + +try: + from ldm.modules.flash_attention import flash_attention_qkv, flash_attention_q_kv + FlASH_AVAILABLE = True +except: + FlASH_AVAILABLE = False + +USE_FLASH = False + + +def enable_flash_attention(): + global USE_FLASH + USE_FLASH = True + if FlASH_AVAILABLE is False: + print("Please install flash attention to activate new attention kernel.\n" + + "Use \'pip install git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn\'") + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + dim_head = q.shape[-1] / self.heads + + if USE_FLASH and FlASH_AVAILABLE and q.dtype in (torch.float16, torch.bfloat16) and \ + dim_head <= 128 and (dim_head % 8) == 0: + # print("in flash") + if q.shape[1] == k.shape[1]: + out = self._flash_attention_qkv(q, k, v) + else: + out = self._flash_attention_q_kv(q, k, v) + else: + out = self._native_attention(q, k, v, self.heads, mask) + + return self.to_out(out) + + def _native_attention(self, q, k, v, h, mask): + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + # attention, what we cannot get enough of + out = sim.softmax(dim=-1) + out = einsum('b i j, b j d -> b i d', out, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return out + + def _flash_attention_qkv(self, q, k, v): + qkv = torch.stack([q, k, v], dim=2) + b = qkv.shape[0] + n = qkv.shape[1] + qkv = rearrange(qkv, 'b n t (h d) -> (b n) t h d', h=self.heads) + out = flash_attention_qkv(qkv, self.scale, b, n) + out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads) + return out + + def _flash_attention_q_kv(self, q, k, v): + kv = torch.stack([k, v], dim=2) + b = q.shape[0] + q_seqlen = q.shape[1] + kv_seqlen = kv.shape[1] + q = rearrange(q, 'b n (h d) -> (b n) h d', h=self.heads) + kv = rearrange(kv, 'b n t (h d) -> (b n) t h d', h=self.heads) + out = flash_attention_q_kv(q, kv, self.scale, b, q_seqlen, kv_seqlen) + out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads) + return out + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, use_checkpoint=False): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.use_checkpoint = use_checkpoint + + def forward(self, x, context=None): + + + if self.use_checkpoint: + return checkpoint(self._forward, x, context) + else: + return self._forward(x, context) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, use_checkpoint=False): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, use_checkpoint=use_checkpoint) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + x = x.contiguous() + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = x.contiguous() + x = self.proj_out(x) + return x + x_in \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/model.py b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..3c28492c550281a0a6596b94de124cac62181e6d --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,862 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from ldm.util import instantiate_from_config +from ldm.modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + +class temb_module(nn.Module): + def __init__(self): + super().__init__() + pass + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + # self.temb = nn.Module() + self.temb = temb_module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + # down = nn.Module() + down = Down_module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + # self.mid = nn.Module() + self.mid = Mid_module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + # up = nn.Module() + up = Up_module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + +class Down_module(nn.Module): + def __init__(self): + super().__init__() + pass + +class Up_module(nn.Module): + def __init__(self): + super().__init__() + pass + +class Mid_module(nn.Module): + def __init__(self): + super().__init__() + pass + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + # down = nn.Module() + down = Down_module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + # self.mid = nn.Module() + self.mid = Mid_module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + # self.mid = nn.Module() + self.mid = Mid_module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + # up = nn.Module() + up = Up_module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z + diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..3aedc2205e134eec0065d4a4b748a448739b8a93 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,1152 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from torch.utils import checkpoint + +from ldm.modules.diffusionmodules.util import ( + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + # for n,p in x.named_parameter(): + # print(f"convert module {n} to_f16") + # p.data = p.data.half() + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + if self.use_checkpoint: + return checkpoint(self._forward, x, emb) + else: + return self._forward(x, emb) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + if self.use_checkpoint: + return checkpoint(self._forward, x) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + else: + return self._forward(x) + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + from_pretrained: str=None + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, use_checkpoint=use_checkpoint, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + # if use_fp16: + # self.convert_to_fp16() + from diffusers.modeling_utils import load_state_dict + if from_pretrained is not None: + state_dict = load_state_dict(from_pretrained) + self._load_pretrained_model(state_dict) + + def _input_blocks_mapping(self, input_dict): + res_dict = {} + for key_, value_ in input_dict.items(): + id_0 = int(key_[13]) + if "resnets" in key_: + id_1 = int(key_[23]) + target_id = 3 * id_0 + 1 + id_1 + post_fix = key_[25:].replace('time_emb_proj', 'emb_layers.1')\ + .replace('norm1', 'in_layers.0')\ + .replace('norm2', 'out_layers.0')\ + .replace('conv1', 'in_layers.2')\ + .replace('conv2', 'out_layers.3')\ + .replace('conv_shortcut', 'skip_connection') + res_dict["input_blocks." + str(target_id) + '.0.' + post_fix] = value_ + elif "attentions" in key_: + id_1 = int(key_[26]) + target_id = 3 * id_0 + 1 + id_1 + post_fix = key_[28:] + res_dict["input_blocks." + str(target_id) + '.1.' + post_fix] = value_ + elif "downsamplers" in key_: + post_fix = key_[35:] + target_id = 3 * (id_0 + 1) + res_dict["input_blocks." + str(target_id) + '.0.op.' + post_fix] = value_ + return res_dict + + + def _mid_blocks_mapping(self, mid_dict): + res_dict = {} + for key_, value_ in mid_dict.items(): + if "resnets" in key_: + temp_key_ =key_.replace('time_emb_proj', 'emb_layers.1') \ + .replace('norm1', 'in_layers.0') \ + .replace('norm2', 'out_layers.0') \ + .replace('conv1', 'in_layers.2') \ + .replace('conv2', 'out_layers.3') \ + .replace('conv_shortcut', 'skip_connection')\ + .replace('middle_block.resnets.0', 'middle_block.0')\ + .replace('middle_block.resnets.1', 'middle_block.2') + res_dict[temp_key_] = value_ + elif "attentions" in key_: + res_dict[key_.replace('attentions.0', '1')] = value_ + return res_dict + + def _other_blocks_mapping(self, other_dict): + res_dict = {} + for key_, value_ in other_dict.items(): + tmp_key = key_.replace('conv_in', 'input_blocks.0.0')\ + .replace('time_embedding.linear_1', 'time_embed.0')\ + .replace('time_embedding.linear_2', 'time_embed.2')\ + .replace('conv_norm_out', 'out.0')\ + .replace('conv_out', 'out.2') + res_dict[tmp_key] = value_ + return res_dict + + + def _output_blocks_mapping(self, output_dict): + res_dict = {} + for key_, value_ in output_dict.items(): + id_0 = int(key_[14]) + if "resnets" in key_: + id_1 = int(key_[24]) + target_id = 3 * id_0 + id_1 + post_fix = key_[26:].replace('time_emb_proj', 'emb_layers.1') \ + .replace('norm1', 'in_layers.0') \ + .replace('norm2', 'out_layers.0') \ + .replace('conv1', 'in_layers.2') \ + .replace('conv2', 'out_layers.3') \ + .replace('conv_shortcut', 'skip_connection') + res_dict["output_blocks." + str(target_id) + '.0.' + post_fix] = value_ + elif "attentions" in key_: + id_1 = int(key_[27]) + target_id = 3 * id_0 + id_1 + post_fix = key_[29:] + res_dict["output_blocks." + str(target_id) + '.1.' + post_fix] = value_ + elif "upsamplers" in key_: + post_fix = key_[34:] + target_id = 3 * (id_0 + 1) - 1 + mid_str = '.2.conv.' if target_id != 2 else '.1.conv.' + res_dict["output_blocks." + str(target_id) + mid_str + post_fix] = value_ + return res_dict + + def _state_key_mapping(self, state_dict: dict): + import re + res_dict = {} + input_dict = {} + mid_dict = {} + output_dict = {} + other_dict = {} + for key_, value_ in state_dict.items(): + if "down_blocks" in key_: + input_dict[key_.replace('down_blocks', 'input_blocks')] = value_ + elif "up_blocks" in key_: + output_dict[key_.replace('up_blocks', 'output_blocks')] = value_ + elif "mid_block" in key_: + mid_dict[key_.replace('mid_block', 'middle_block')] = value_ + else: + other_dict[key_] = value_ + + input_dict = self._input_blocks_mapping(input_dict) + output_dict = self._output_blocks_mapping(output_dict) + mid_dict = self._mid_blocks_mapping(mid_dict) + other_dict = self._other_blocks_mapping(other_dict) + # key_list = state_dict.keys() + # key_str = " ".join(key_list) + + # for key_, val_ in state_dict.items(): + # key_ = key_.replace("down_blocks", "input_blocks")\ + # .replace("up_blocks", 'output_blocks') + # res_dict[key_] = val_ + res_dict.update(input_dict) + res_dict.update(output_dict) + res_dict.update(mid_dict) + res_dict.update(other_dict) + + return res_dict + + def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False): + state_dict = self._state_key_mapping(state_dict) + model_state_dict = self.state_dict() + loaded_keys = [k for k in state_dict.keys()] + expected_keys = list(model_state_dict.keys()) + original_loaded_keys = loaded_keys + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = self._load_state_dict_into_model(state_dict) + return missing_keys, unexpected_keys, mismatched_keys, error_msgs + + def _load_state_dict_into_model(self, state_dict): + # Convert old format to new format if needed from a PyTorch state_dict + # copy state_dict so _load_from_state_dict can modify it + state_dict = state_dict.copy() + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: torch.nn.Module, prefix=""): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(self) + + return error_msgs + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(self.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(self.dtype) + return self.out(h) + diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/util.py b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a7db9369c58ae5adf99fe714a4725e6adf350089 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,276 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_fp16=True): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + if use_fp16: + return embedding.half() + else: + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels, precision=16): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + if precision == 16: + return GroupNorm16(16, channels) + else: + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + +class GroupNorm16(nn.GroupNorm): + def forward(self, x): + return super().forward(x.half()).type(x.dtype) + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/modules/distributions/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/tutorial/stable_diffusion/ldm/modules/distributions/distributions.py b/examples/tutorial/stable_diffusion/ldm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/examples/tutorial/stable_diffusion/ldm/modules/ema.py b/examples/tutorial/stable_diffusion/ldm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..c8c75af43565f6e140287644aaaefa97dd6e67c5 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/ema.py @@ -0,0 +1,76 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/examples/tutorial/stable_diffusion/ldm/modules/encoders/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/tutorial/stable_diffusion/ldm/modules/encoders/modules.py b/examples/tutorial/stable_diffusion/ldm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..8cfc01e5ded4b038341b6bc7d33246c6f9e70d39 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/encoders/modules.py @@ -0,0 +1,264 @@ +import types + +import torch +import torch.nn as nn +from functools import partial +import clip +from einops import rearrange, repeat +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig +import kornia +from transformers.models.clip.modeling_clip import CLIPTextTransformer + +from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer)) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast # TODO: add to reuquirements + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, + device="cuda",use_tokenizer=True, embedding_dropout=0.0): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout) + + def forward(self, text): + if self.use_tknz_fn: + tokens = self.tknz_fn(text)#.to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, text): + # output of length 77 + return self(text) + + +class SpatialRescaler(nn.Module): + def __init__(self, + n_stages=1, + method='bilinear', + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) + + def forward(self,x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + + +class CLIPTextModelZero(CLIPTextModel): + config_class = CLIPTextConfig + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + self.text_model = CLIPTextTransformerZero(config) + +class CLIPTextTransformerZero(CLIPTextTransformer): + def _build_causal_attention_mask(self, bsz, seq_len): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(bsz, seq_len, seq_len) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(1) # expand mask + return mask.half() + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, use_fp16=True): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version) + + if use_fp16: + self.transformer = CLIPTextModelZero.from_pretrained(version) + else: + self.transformer = CLIPTextModel.from_pretrained(version) + + # print(self.transformer.modules()) + # print("check model dtyoe: {}, {}".format(self.tokenizer.dtype, self.transformer.dtype)) + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + # tokens = batch_encoding["input_ids"].to(self.device) + tokens = batch_encoding["input_ids"].to(self.device) + # print("token type: {}".format(tokens.dtype)) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPTextEmbedder(nn.Module): + """ + Uses the CLIP transformer encoder for text. + """ + def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + super().__init__() + self.model, _ = clip.load(version, jit=False, device="cpu") + self.device = device + self.max_length = max_length + self.n_repeat = n_repeat + self.normalize = normalize + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = clip.tokenize(text).to(self.device) + z = self.model.encode_text(tokens) + if self.normalize: + z = z / torch.linalg.norm(z, dim=1, keepdim=True) + return z + + def encode(self, text): + z = self(text) + if z.ndim==2: + z = z[:, None, :] + z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) + return z + + +class FrozenClipImageEmbedder(nn.Module): + """ + Uses the CLIP image encoder. + """ + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=False, + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + return self.model.encode_image(self.preprocess(x)) + + +if __name__ == "__main__": + from ldm.util import count_params + model = FrozenCLIPEmbedder() + count_params(model, verbose=True) \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/modules/flash_attention.py b/examples/tutorial/stable_diffusion/ldm/modules/flash_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2a7a738798579204b44611516579ee47b36639b4 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/flash_attention.py @@ -0,0 +1,50 @@ +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton) +""" + +import torch +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func +except ImportError: + raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention') + + + +def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len): + """ + Arguments: + qkv: (batch*seq, 3, nheads, headdim) + batch_size: int. + seq_len: int. + sm_scale: float. The scaling of QK^T before applying softmax. + Return: + out: (total, nheads, headdim). + """ + max_s = seq_len + cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, + device=qkv.device) + out = flash_attn_unpadded_qkvpacked_func( + qkv, cu_seqlens, max_s, 0.0, + softmax_scale=sm_scale, causal=False + ) + return out + + +def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen): + """ + Arguments: + q: (batch*seq, nheads, headdim) + kv: (batch*seq, 2, nheads, headdim) + batch_size: int. + seq_len: int. + sm_scale: float. The scaling of QK^T before applying softmax. + Return: + out: (total, nheads, headdim). + """ + cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device) + cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device) + out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, 0.0, sm_scale) + return out diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/__init__.py @@ -0,0 +1,2 @@ +from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr +from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan.py b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan.py new file mode 100644 index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan.py @@ -0,0 +1,730 @@ +# -*- coding: utf-8 -*- +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + elif i == 1: + image = add_blur(image, sf=sf) + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image":image} + return example + + +# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + + return img, hq + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') + + diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1f823996bf559e9b015ea9aa2b3cd38dd13af1 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py @@ -0,0 +1,650 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2/4 + wd = wd/4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + # elif i == 1: + # image = add_blur(image, sf=sf) + + if i == 0: + pass + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + # + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image": image} + return example + + + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils/test.png b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils/test.png new file mode 100644 index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6 Binary files /dev/null and b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils/test.png differ diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils_image.py b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils_image.py new file mode 100644 index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils_image.py @@ -0,0 +1,916 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + + +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2๏ผŒcmap=cmap) + plt.show() + + +''' +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) +# print(w1) +# print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(unit) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. +# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +''' +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +''' + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border:h-border, border:w-border] + return img + + +''' +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# ycbcr2rgb(img): +# -------------------------------------------- +''' + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + print('---') +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/modules/losses/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..876d7c5bd6e3245ee77feb4c482b7a8143604ad5 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/losses/__init__.py @@ -0,0 +1 @@ +from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/modules/losses/contperceptual.py b/examples/tutorial/stable_diffusion/ldm/modules/losses/contperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..672c1e32a1389def02461c0781339681060c540e --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/losses/contperceptual.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn + +from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? + + +class LPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_loss="hinge"): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, inputs, reconstructions, posteriors, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", + weights=None): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights*nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log + diff --git a/examples/tutorial/stable_diffusion/ldm/modules/losses/vqperceptual.py b/examples/tutorial/stable_diffusion/ldm/modules/losses/vqperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..f69981769e4bd5462600458c4fcf26620f7e4306 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/losses/vqperceptual.py @@ -0,0 +1,167 @@ +import torch +from torch import nn +import torch.nn.functional as F +from einops import repeat + +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.losses.lpips import LPIPS +from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss + + +def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): + assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] + loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) + loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) + loss_real = (weights * loss_real).sum() / weights.sum() + loss_fake = (weights * loss_fake).sum() / weights.sum() + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def measure_perplexity(predicted_indices, n_embed): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use + +def l1(x, y): + return torch.abs(x-y) + + +def l2(x, y): + return torch.pow((x-y), 2) + + +class VQLPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", + pixel_loss="l1"): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + assert perceptual_loss in ["lpips", "clips", "dists"] + assert pixel_loss in ["l1", "l2"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + if perceptual_loss == "lpips": + print(f"{self.__class__.__name__}: Running with LPIPS.") + self.perceptual_loss = LPIPS().eval() + else: + raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") + self.perceptual_weight = perceptual_weight + + if pixel_loss == "l1": + self.pixel_loss = l1 + else: + self.pixel_loss = l2 + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + self.n_classes = n_classes + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", predicted_indices=None): + if not exists(codebook_loss): + codebook_loss = torch.tensor([0.]).to(inputs.device) + #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + else: + p_loss = torch.tensor([0.0]) + + nll_loss = rec_loss + #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + if predicted_indices is not None: + assert self.n_classes is not None + with torch.no_grad(): + perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) + log[f"{split}/perplexity"] = perplexity + log[f"{split}/cluster_usage"] = cluster_usage + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/examples/tutorial/stable_diffusion/ldm/modules/x_transformer.py b/examples/tutorial/stable_diffusion/ldm/modules/x_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc15bf9cfe0111a910e7de33d04ffdec3877576 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/x_transformer.py @@ -0,0 +1,641 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat, reduce + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + return inner + + +def not_equals(val): + def inner(x): + return x != val + return inner + + +def equals(val): + def inner(x): + return x == val + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + #self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out + diff --git a/examples/tutorial/stable_diffusion/ldm/util.py b/examples/tutorial/stable_diffusion/ldm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..8ba38853e7a07228cc2c187742b5c45d7359b3f9 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/util.py @@ -0,0 +1,203 @@ +import importlib + +import torch +import numpy as np +from collections import abc +from einops import rearrange +from functools import partial + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i: i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == 'ndarray': + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == 'list': + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/examples/tutorial/stable_diffusion/main.py b/examples/tutorial/stable_diffusion/main.py new file mode 100644 index 0000000000000000000000000000000000000000..7cd00e4c0c264e1f6323e027edf72e1fd76decc3 --- /dev/null +++ b/examples/tutorial/stable_diffusion/main.py @@ -0,0 +1,830 @@ +import argparse, os, sys, datetime, glob, importlib, csv +import numpy as np +import time +import torch +import torchvision +import pytorch_lightning as pl + +from packaging import version +from omegaconf import OmegaConf +from torch.utils.data import random_split, DataLoader, Dataset, Subset +from functools import partial +from PIL import Image +# from pytorch_lightning.strategies.colossalai import ColossalAIStrategy +# from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from prefetch_generator import BackgroundGenerator + +from pytorch_lightning import seed_everything +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from pytorch_lightning.utilities import rank_zero_info +from diffusers.models.unet_2d import UNet2DModel + +from clip.model import Bottleneck +from transformers.models.clip.modeling_clip import CLIPTextTransformer + +from ldm.data.base import Txt2ImgIterableBaseDataset +from ldm.util import instantiate_from_config +import clip +from einops import rearrange, repeat +from transformers import CLIPTokenizer, CLIPTextModel +import kornia + +from ldm.modules.x_transformer import * +from ldm.modules.encoders.modules import * +from taming.modules.diffusionmodules.model import ResnetBlock +from taming.modules.transformer.mingpt import * +from taming.modules.transformer.permuter import * + + +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import AutoencoderKL +from ldm.models.autoencoder import * +from ldm.models.diffusion.ddim import * +from ldm.modules.diffusionmodules.openaimodel import * +from ldm.modules.diffusionmodules.model import * +from ldm.modules.diffusionmodules.model import Decoder, Encoder, Up_module, Down_module, Mid_module, temb_module +from ldm.modules.attention import enable_flash_attention + +class DataLoaderX(DataLoader): + + def __iter__(self): + return BackgroundGenerator(super().__iter__()) + + +def get_parser(**parser_kwargs): + def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + parser = argparse.ArgumentParser(**parser_kwargs) + parser.add_argument( + "-n", + "--name", + type=str, + const=True, + default="", + nargs="?", + help="postfix for logdir", + ) + parser.add_argument( + "-r", + "--resume", + type=str, + const=True, + default="", + nargs="?", + help="resume from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-b", + "--base", + nargs="*", + metavar="base_config.yaml", + help="paths to base configs. Loaded from left-to-right. " + "Parameters can be overwritten or added with command-line options of the form `--key value`.", + default=list(), + ) + parser.add_argument( + "-t", + "--train", + type=str2bool, + const=True, + default=False, + nargs="?", + help="train", + ) + parser.add_argument( + "--no-test", + type=str2bool, + const=True, + default=False, + nargs="?", + help="disable test", + ) + parser.add_argument( + "-p", + "--project", + help="name of new or path to existing project" + ) + parser.add_argument( + "-d", + "--debug", + type=str2bool, + nargs="?", + const=True, + default=False, + help="enable post-mortem debugging", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + default=23, + help="seed for seed_everything", + ) + parser.add_argument( + "-f", + "--postfix", + type=str, + default="", + help="post-postfix for default name", + ) + parser.add_argument( + "-l", + "--logdir", + type=str, + default="logs", + help="directory for logging dat shit", + ) + parser.add_argument( + "--scale_lr", + type=str2bool, + nargs="?", + const=True, + default=True, + help="scale base-lr by ngpu * batch_size * n_accumulate", + ) + parser.add_argument( + "--use_fp16", + type=str2bool, + nargs="?", + const=True, + default=True, + help="whether to use fp16", + ) + parser.add_argument( + "--flash", + type=str2bool, + const=True, + default=False, + nargs="?", + help="whether to use flash attention", + ) + return parser + + +def nondefault_trainer_args(opt): + parser = argparse.ArgumentParser() + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args([]) + return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) + + +class WrappedDataset(Dataset): + """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" + + def __init__(self, dataset): + self.data = dataset + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +def worker_init_fn(_): + worker_info = torch.utils.data.get_worker_info() + + dataset = worker_info.dataset + worker_id = worker_info.id + + if isinstance(dataset, Txt2ImgIterableBaseDataset): + split_size = dataset.num_records // worker_info.num_workers + # reset num_records to the true number to retain reliable length information + dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] + current_id = np.random.choice(len(np.random.get_state()[1]), 1) + return np.random.seed(np.random.get_state()[1][current_id] + worker_id) + else: + return np.random.seed(np.random.get_state()[1][0] + worker_id) + + +class DataModuleFromConfig(pl.LightningDataModule): + def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, + wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False, + shuffle_val_dataloader=False): + super().__init__() + self.batch_size = batch_size + self.dataset_configs = dict() + self.num_workers = num_workers if num_workers is not None else batch_size * 2 + self.use_worker_init_fn = use_worker_init_fn + if train is not None: + self.dataset_configs["train"] = train + self.train_dataloader = self._train_dataloader + if validation is not None: + self.dataset_configs["validation"] = validation + self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) + if test is not None: + self.dataset_configs["test"] = test + self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) + if predict is not None: + self.dataset_configs["predict"] = predict + self.predict_dataloader = self._predict_dataloader + self.wrap = wrap + + def prepare_data(self): + for data_cfg in self.dataset_configs.values(): + instantiate_from_config(data_cfg) + + def setup(self, stage=None): + self.datasets = dict( + (k, instantiate_from_config(self.dataset_configs[k])) + for k in self.dataset_configs) + if self.wrap: + for k in self.datasets: + self.datasets[k] = WrappedDataset(self.datasets[k]) + + def _train_dataloader(self): + is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoaderX(self.datasets["train"], batch_size=self.batch_size, + num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True, + worker_init_fn=init_fn) + + def _val_dataloader(self, shuffle=False): + if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoaderX(self.datasets["validation"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle) + + def _test_dataloader(self, shuffle=False): + is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + + # do not shuffle dataloader for iterable dataset + shuffle = shuffle and (not is_iterable_dataset) + + return DataLoaderX(self.datasets["test"], batch_size=self.batch_size, + num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle) + + def _predict_dataloader(self, shuffle=False): + if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoaderX(self.datasets["predict"], batch_size=self.batch_size, + num_workers=self.num_workers, worker_init_fn=init_fn) + + +class SetupCallback(Callback): + def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): + super().__init__() + self.resume = resume + self.now = now + self.logdir = logdir + self.ckptdir = ckptdir + self.cfgdir = cfgdir + self.config = config + self.lightning_config = lightning_config + + def on_keyboard_interrupt(self, trainer, pl_module): + if trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(self.ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + # def on_pretrain_routine_start(self, trainer, pl_module): + def on_fit_start(self, trainer, pl_module): + if trainer.global_rank == 0: + # Create logdirs and save configs + os.makedirs(self.logdir, exist_ok=True) + os.makedirs(self.ckptdir, exist_ok=True) + os.makedirs(self.cfgdir, exist_ok=True) + + if "callbacks" in self.lightning_config: + if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']: + os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) + print("Project config") + print(OmegaConf.to_yaml(self.config)) + OmegaConf.save(self.config, + os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) + + print("Lightning config") + print(OmegaConf.to_yaml(self.lightning_config)) + OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), + os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) + + else: + # ModelCheckpoint callback created log directory --- remove it + if not self.resume and os.path.exists(self.logdir): + dst, name = os.path.split(self.logdir) + dst = os.path.join(dst, "child_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + try: + os.rename(self.logdir, dst) + except FileNotFoundError: + pass + + +class ImageLogger(Callback): + def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True, + rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, + log_images_kwargs=None): + super().__init__() + self.rescale = rescale + self.batch_freq = batch_frequency + self.max_images = max_images + self.logger_log_images = { + pl.loggers.CSVLogger: self._testtube, + } + self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)] + if not increase_log_steps: + self.log_steps = [self.batch_freq] + self.clamp = clamp + self.disabled = disabled + self.log_on_batch_idx = log_on_batch_idx + self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} + self.log_first_step = log_first_step + + @rank_zero_only + def _testtube(self, pl_module, images, batch_idx, split): + for k in images: + grid = torchvision.utils.make_grid(images[k]) + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + + tag = f"{split}/{k}" + pl_module.logger.experiment.add_image( + tag, grid, + global_step=pl_module.global_step) + + @rank_zero_only + def log_local(self, save_dir, split, images, + global_step, current_epoch, batch_idx): + root = os.path.join(save_dir, "images", split) + for k in images: + grid = torchvision.utils.make_grid(images[k], nrow=4) + if self.rescale: + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) + grid = grid.numpy() + grid = (grid * 255).astype(np.uint8) + filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( + k, + global_step, + current_epoch, + batch_idx) + path = os.path.join(root, filename) + os.makedirs(os.path.split(path)[0], exist_ok=True) + Image.fromarray(grid).save(path) + + def log_img(self, pl_module, batch, batch_idx, split="train"): + check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step + if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 + hasattr(pl_module, "log_images") and + callable(pl_module.log_images) and + self.max_images > 0): + logger = type(pl_module.logger) + + is_train = pl_module.training + if is_train: + pl_module.eval() + + with torch.no_grad(): + images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) + + for k in images: + N = min(images[k].shape[0], self.max_images) + images[k] = images[k][:N] + if isinstance(images[k], torch.Tensor): + images[k] = images[k].detach().cpu() + if self.clamp: + images[k] = torch.clamp(images[k], -1., 1.) + + self.log_local(pl_module.logger.save_dir, split, images, + pl_module.global_step, pl_module.current_epoch, batch_idx) + + logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) + logger_log_images(pl_module, images, pl_module.global_step, split) + + if is_train: + pl_module.train() + + def check_frequency(self, check_idx): + if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( + check_idx > 0 or self.log_first_step): + try: + self.log_steps.pop(0) + except IndexError as e: + print(e) + pass + return True + return False + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + # if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): + # self.log_img(pl_module, batch, batch_idx, split="train") + pass + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self.disabled and pl_module.global_step > 0: + self.log_img(pl_module, batch, batch_idx, split="val") + if hasattr(pl_module, 'calibrate_grad_norm'): + if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: + self.log_gradients(trainer, pl_module, batch_idx=batch_idx) + + +class CUDACallback(Callback): + # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py + + def on_train_start(self, trainer, pl_module): + rank_zero_info("Training is starting") + + def on_train_end(self, trainer, pl_module): + rank_zero_info("Training is ending") + + def on_train_epoch_start(self, trainer, pl_module): + # Reset the memory use counter + torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index) + torch.cuda.synchronize(trainer.strategy.root_device.index) + self.start_time = time.time() + + def on_train_epoch_end(self, trainer, pl_module): + torch.cuda.synchronize(trainer.strategy.root_device.index) + max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2 ** 20 + epoch_time = time.time() - self.start_time + + try: + max_memory = trainer.strategy.reduce(max_memory) + epoch_time = trainer.strategy.reduce(epoch_time) + + rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") + rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") + except AttributeError: + pass + + +if __name__ == "__main__": + # custom parser to specify config files, train, test and debug mode, + # postfix, resume. + # `--key value` arguments are interpreted as arguments to the trainer. + # `nested.key=value` arguments are interpreted as config parameters. + # configs are merged from left-to-right followed by command line parameters. + + # model: + # base_learning_rate: float + # target: path to lightning module + # params: + # key: value + # data: + # target: main.DataModuleFromConfig + # params: + # batch_size: int + # wrap: bool + # train: + # target: path to train dataset + # params: + # key: value + # validation: + # target: path to validation dataset + # params: + # key: value + # test: + # target: path to test dataset + # params: + # key: value + # lightning: (optional, has sane defaults and can be specified on cmdline) + # trainer: + # additional arguments to trainer + # logger: + # logger to instantiate + # modelcheckpoint: + # modelcheckpoint to instantiate + # callbacks: + # callback1: + # target: importpath + # params: + # key: value + + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + + # add cwd for convenience and to make classes in this file available when + # running as `python main.py` + # (in particular `main.DataModuleFromConfig`) + sys.path.append(os.getcwd()) + + parser = get_parser() + parser = Trainer.add_argparse_args(parser) + + opt, unknown = parser.parse_known_args() + if opt.name and opt.resume: + raise ValueError( + "-n/--name and -r/--resume cannot be specified both." + "If you want to resume training in a new log folder, " + "use -n/--name in combination with --resume_from_checkpoint" + ) + if opt.flash: + enable_flash_attention() + if opt.resume: + if not os.path.exists(opt.resume): + raise ValueError("Cannot find {}".format(opt.resume)) + if os.path.isfile(opt.resume): + paths = opt.resume.split("/") + # idx = len(paths)-paths[::-1].index("logs")+1 + # logdir = "/".join(paths[:idx]) + logdir = "/".join(paths[:-2]) + ckpt = opt.resume + else: + assert os.path.isdir(opt.resume), opt.resume + logdir = opt.resume.rstrip("/") + ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") + + opt.resume_from_checkpoint = ckpt + base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) + opt.base = base_configs + opt.base + _tmp = logdir.split("/") + nowname = _tmp[-1] + else: + if opt.name: + name = "_" + opt.name + elif opt.base: + cfg_fname = os.path.split(opt.base[0])[-1] + cfg_name = os.path.splitext(cfg_fname)[0] + name = "_" + cfg_name + else: + name = "" + nowname = now + name + opt.postfix + logdir = os.path.join(opt.logdir, nowname) + + ckptdir = os.path.join(logdir, "checkpoints") + cfgdir = os.path.join(logdir, "configs") + seed_everything(opt.seed) + + try: + # init and save configs + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + lightning_config = config.pop("lightning", OmegaConf.create()) + # merge trainer cli with config + trainer_config = lightning_config.get("trainer", OmegaConf.create()) + + for k in nondefault_trainer_args(opt): + trainer_config[k] = getattr(opt, k) + + print(trainer_config) + if not trainer_config["accelerator"] == "gpu": + del trainer_config["accelerator"] + cpu = True + print("Running on CPU") + else: + cpu = False + print("Running on GPU") + trainer_opt = argparse.Namespace(**trainer_config) + lightning_config.trainer = trainer_config + + # model + use_fp16 = trainer_config.get("precision", 32) == 16 + if use_fp16: + config.model["params"].update({"use_fp16": True}) + print("Using FP16 = {}".format(config.model["params"]["use_fp16"])) + else: + config.model["params"].update({"use_fp16": False}) + print("Using FP16 = {}".format(config.model["params"]["use_fp16"])) + + model = instantiate_from_config(config.model) + # trainer and callbacks + trainer_kwargs = dict() + + # config the logger + # default logger configs + default_logger_cfgs = { + "wandb": { + "target": "pytorch_lightning.loggers.WandbLogger", + "params": { + "name": nowname, + "save_dir": logdir, + "offline": opt.debug, + "id": nowname, + } + }, + "tensorboard":{ + "target": "pytorch_lightning.loggers.TensorBoardLogger", + "params":{ + "save_dir": logdir, + "name": "diff_tb", + "log_graph": True + } + } + } + + default_logger_cfg = default_logger_cfgs["tensorboard"] + if "logger" in lightning_config: + logger_cfg = lightning_config.logger + else: + logger_cfg = default_logger_cfg + logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) + trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) + + # config the strategy, defualt is ddp + if "strategy" in trainer_config: + strategy_cfg = trainer_config["strategy"] + print("Using strategy: {}".format(strategy_cfg["target"])) + else: + strategy_cfg = { + "target": "pytorch_lightning.strategies.DDPStrategy", + "params": { + "find_unused_parameters": False + } + } + print("Using strategy: DDPStrategy") + + trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg) + + # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to + # specify which metric is used to determine best models + default_modelckpt_cfg = { + "target": "pytorch_lightning.callbacks.ModelCheckpoint", + "params": { + "dirpath": ckptdir, + "filename": "{epoch:06}", + "verbose": True, + "save_last": True, + } + } + if hasattr(model, "monitor"): + print(f"Monitoring {model.monitor} as checkpoint metric.") + default_modelckpt_cfg["params"]["monitor"] = model.monitor + default_modelckpt_cfg["params"]["save_top_k"] = 3 + + if "modelcheckpoint" in lightning_config: + modelckpt_cfg = lightning_config.modelcheckpoint + else: + modelckpt_cfg = OmegaConf.create() + modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) + print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") + if version.parse(pl.__version__) < version.parse('1.4.0'): + trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) + + # add callback which sets up log directory + default_callbacks_cfg = { + "setup_callback": { + "target": "main.SetupCallback", + "params": { + "resume": opt.resume, + "now": now, + "logdir": logdir, + "ckptdir": ckptdir, + "cfgdir": cfgdir, + "config": config, + "lightning_config": lightning_config, + } + }, + "image_logger": { + "target": "main.ImageLogger", + "params": { + "batch_frequency": 750, + "max_images": 4, + "clamp": True + } + }, + "learning_rate_logger": { + "target": "main.LearningRateMonitor", + "params": { + "logging_interval": "step", + # "log_momentum": True + } + }, + "cuda_callback": { + "target": "main.CUDACallback" + }, + } + if version.parse(pl.__version__) >= version.parse('1.4.0'): + default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg}) + + if "callbacks" in lightning_config: + callbacks_cfg = lightning_config.callbacks + else: + callbacks_cfg = OmegaConf.create() + + if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg: + print( + 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.') + default_metrics_over_trainsteps_ckpt_dict = { + 'metrics_over_trainsteps_checkpoint': + {"target": 'pytorch_lightning.callbacks.ModelCheckpoint', + 'params': { + "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), + "filename": "{epoch:06}-{step:09}", + "verbose": True, + 'save_top_k': -1, + 'every_n_train_steps': 10000, + 'save_weights_only': True + } + } + } + default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) + + callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) + if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'): + callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint + elif 'ignore_keys_callback' in callbacks_cfg: + del callbacks_cfg['ignore_keys_callback'] + + trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] + + trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) + trainer.logdir = logdir ### + + # data + data = instantiate_from_config(config.data) + # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html + # calling these ourselves should not be necessary but it is. + # lightning still takes care of proper multiprocessing though + data.prepare_data() + data.setup() + print("#### Data #####") + for k in data.datasets: + print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") + + # configure learning rate + bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate + if not cpu: + ngpu = trainer_config["devices"] + else: + ngpu = 1 + if 'accumulate_grad_batches' in lightning_config.trainer: + accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches + else: + accumulate_grad_batches = 1 + print(f"accumulate_grad_batches = {accumulate_grad_batches}") + lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches + if opt.scale_lr: + model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr + print( + "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( + model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) + else: + model.learning_rate = base_lr + print("++++ NOT USING LR SCALING ++++") + print(f"Setting learning rate to {model.learning_rate:.2e}") + + + # allow checkpointing via USR1 + def melk(*args, **kwargs): + # run all checkpoint hooks + if trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + + def divein(*args, **kwargs): + if trainer.global_rank == 0: + import pudb; + pudb.set_trace() + + + import signal + + signal.signal(signal.SIGUSR1, melk) + signal.signal(signal.SIGUSR2, divein) + + # run + if opt.train: + try: + for name, m in model.named_parameters(): + print(name) + trainer.fit(model, data) + except Exception: + melk() + raise + # if not opt.no_test and not trainer.interrupted: + # trainer.test(model, data) + except Exception: + if opt.debug and trainer.global_rank == 0: + try: + import pudb as debugger + except ImportError: + import pdb as debugger + debugger.post_mortem() + raise + finally: + # move newly created debug project to debug_runs + if opt.debug and not opt.resume and trainer.global_rank == 0: + dst, name = os.path.split(logdir) + dst = os.path.join(dst, "debug_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + os.rename(logdir, dst) + if trainer.global_rank == 0: + print(trainer.profiler.summary()) diff --git a/examples/tutorial/stable_diffusion/requirements.txt b/examples/tutorial/stable_diffusion/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..54bc00029642135496212e63737cd2ce6390c616 --- /dev/null +++ b/examples/tutorial/stable_diffusion/requirements.txt @@ -0,0 +1,21 @@ +albumentations==0.4.3 +diffusers +pudb==2019.2 +datasets +invisible-watermark +imageio==2.9.0 +imageio-ffmpeg==0.4.2 +omegaconf==2.1.1 +multiprocess +test-tube>=0.7.5 +streamlit>=0.73.1 +einops==0.3.0 +torch-fidelity==0.3.0 +transformers==4.19.2 +torchmetrics==0.6.0 +kornia==0.6 +opencv-python==4.6.0.66 +prefetch_generator +-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers +-e git+https://github.com/openai/CLIP.git@main#egg=clip +-e . diff --git a/examples/tutorial/stable_diffusion/scripts/download_first_stages.sh b/examples/tutorial/stable_diffusion/scripts/download_first_stages.sh new file mode 100644 index 0000000000000000000000000000000000000000..a8d79e99ccdff0a8d8762f23f3c0642401f32f6c --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/download_first_stages.sh @@ -0,0 +1,41 @@ +#!/bin/bash +wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip +wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip +wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip +wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip +wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip +wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip +wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip +wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip +wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip + + + +cd models/first_stage_models/kl-f4 +unzip -o model.zip + +cd ../kl-f8 +unzip -o model.zip + +cd ../kl-f16 +unzip -o model.zip + +cd ../kl-f32 +unzip -o model.zip + +cd ../vq-f4 +unzip -o model.zip + +cd ../vq-f4-noattn +unzip -o model.zip + +cd ../vq-f8 +unzip -o model.zip + +cd ../vq-f8-n256 +unzip -o model.zip + +cd ../vq-f16 +unzip -o model.zip + +cd ../.. \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/scripts/download_models.sh b/examples/tutorial/stable_diffusion/scripts/download_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..84297d7b8b9a78d241edcd5adaf7d9aa273790de --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/download_models.sh @@ -0,0 +1,49 @@ +#!/bin/bash +wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip +wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip +wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip +wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip +wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip +wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip +wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip +wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip +wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip +wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip +wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip + + + +cd models/ldm/celeba256 +unzip -o celeba-256.zip + +cd ../ffhq256 +unzip -o ffhq-256.zip + +cd ../lsun_churches256 +unzip -o lsun_churches-256.zip + +cd ../lsun_beds256 +unzip -o lsun_beds-256.zip + +cd ../text2img256 +unzip -o model.zip + +cd ../cin256 +unzip -o model.zip + +cd ../semantic_synthesis512 +unzip -o model.zip + +cd ../semantic_synthesis256 +unzip -o model.zip + +cd ../bsr_sr +unzip -o model.zip + +cd ../layout2img-openimages256 +unzip -o model.zip + +cd ../inpainting_big +unzip -o model.zip + +cd ../.. diff --git a/examples/tutorial/stable_diffusion/scripts/img2img.py b/examples/tutorial/stable_diffusion/scripts/img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..421e2151d9e9de75a142f5d5f532333645a36287 --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/img2img.py @@ -0,0 +1,293 @@ +"""make variations of input image""" + +import argparse, os, sys, glob +import PIL +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange, repeat +from torchvision.utils import make_grid +from torch import autocast +from contextlib import nullcontext +import time +from pytorch_lightning import seed_everything + +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +def load_img(path): + image = Image.open(path).convert("RGB") + w, h = image.size + print(f"loaded input image of size ({w}, {h}) from {path}") + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.*image - 1. + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + + parser.add_argument( + "--init-img", + type=str, + nargs="?", + help="path to the input image" + ) + + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/img2img-samples" + ) + + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + + parser.add_argument( + "--skip_save", + action='store_true', + help="do not save indiviual samples. For speed measurements.", + ) + + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across all samples ", + ) + + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor, most often 8 or 16", + ) + parser.add_argument( + "--n_samples", + type=int, + default=2, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=5.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + + parser.add_argument( + "--strength", + type=float, + default=0.75, + help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image", + ) + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v1-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + default="models/ldm/stable-diffusion-v1/model.ckpt", + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + + opt = parser.parse_args() + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + if opt.plms: + raise NotImplementedError("PLMS sampler not (yet) supported") + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + assert os.path.isfile(opt.init_img) + init_image = load_img(opt.init_img).to(device) + init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) + init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space + + sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False) + + assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(opt.strength * opt.ddim_steps) + print(f"target t_enc is {t_enc} steps") + + precision_scope = autocast if opt.precision == "autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + tic = time.time() + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + + # encode (scaled latent) + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device)) + # decode it + samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc,) + + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + if not opt.skip_save: + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + Image.fromarray(x_sample.astype(np.uint8)).save( + os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + all_samples.append(x_samples) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + toc = time.time() + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" + f" \nEnjoy.") + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/stable_diffusion/scripts/inpaint.py b/examples/tutorial/stable_diffusion/scripts/inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..d6e6387a9a3b0afa73fae8af25f43a8ba856240e --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/inpaint.py @@ -0,0 +1,98 @@ +import argparse, os, sys, glob +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm +import numpy as np +import torch +from main import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler + + +def make_batch(image, mask, device): + image = np.array(Image.open(image).convert("RGB")) + image = image.astype(np.float32)/255.0 + image = image[None].transpose(0,3,1,2) + image = torch.from_numpy(image) + + mask = np.array(Image.open(mask).convert("L")) + mask = mask.astype(np.float32)/255.0 + mask = mask[None,None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = (1-mask)*image + + batch = {"image": image, "mask": mask, "masked_image": masked_image} + for k in batch: + batch[k] = batch[k].to(device=device) + batch[k] = batch[k]*2.0-1.0 + return batch + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--indir", + type=str, + nargs="?", + help="dir containing image-mask pairs (`example.png` and `example_mask.png`)", + ) + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + opt = parser.parse_args() + + masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png"))) + images = [x.replace("_mask.png", ".png") for x in masks] + print(f"Found {len(masks)} inputs.") + + config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") + model = instantiate_from_config(config.model) + model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], + strict=False) + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + with torch.no_grad(): + with model.ema_scope(): + for image, mask in tqdm(zip(images, masks)): + outpath = os.path.join(opt.outdir, os.path.split(image)[1]) + batch = make_batch(image, mask, device=device) + + # encode masked image and concat downsampled mask + c = model.cond_stage_model.encode(batch["masked_image"]) + cc = torch.nn.functional.interpolate(batch["mask"], + size=c.shape[-2:]) + c = torch.cat((c, cc), dim=1) + + shape = (c.shape[1]-1,)+c.shape[2:] + samples_ddim, _ = sampler.sample(S=opt.steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False) + x_samples_ddim = model.decode_first_stage(samples_ddim) + + image = torch.clamp((batch["image"]+1.0)/2.0, + min=0.0, max=1.0) + mask = torch.clamp((batch["mask"]+1.0)/2.0, + min=0.0, max=1.0) + predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, + min=0.0, max=1.0) + + inpainted = (1-mask)*image+mask*predicted_image + inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 + Image.fromarray(inpainted.astype(np.uint8)).save(outpath) diff --git a/examples/tutorial/stable_diffusion/scripts/knn2img.py b/examples/tutorial/stable_diffusion/scripts/knn2img.py new file mode 100644 index 0000000000000000000000000000000000000000..e6eaaecab53eac9c97051c9a5cb457a240679725 --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/knn2img.py @@ -0,0 +1,398 @@ +import argparse, os, sys, glob +import clip +import torch +import torch.nn as nn +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange, repeat +from torchvision.utils import make_grid +import scann +import time +from multiprocessing import cpu_count + +from ldm.util import instantiate_from_config, parallel_data_prefetch +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder + +DATABASES = [ + "openimages", + "artbench-art_nouveau", + "artbench-baroque", + "artbench-expressionism", + "artbench-impressionism", + "artbench-post_impressionism", + "artbench-realism", + "artbench-romanticism", + "artbench-renaissance", + "artbench-surrealism", + "artbench-ukiyo_e", +] + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +class Searcher(object): + def __init__(self, database, retriever_version='ViT-L/14'): + assert database in DATABASES + # self.database = self.load_database(database) + self.database_name = database + self.searcher_savedir = f'data/rdm/searchers/{self.database_name}' + self.database_path = f'data/rdm/retrieval_databases/{self.database_name}' + self.retriever = self.load_retriever(version=retriever_version) + self.database = {'embedding': [], + 'img_id': [], + 'patch_coords': []} + self.load_database() + self.load_searcher() + + def train_searcher(self, k, + metric='dot_product', + searcher_savedir=None): + + print('Start training searcher') + searcher = scann.scann_ops_pybind.builder(self.database['embedding'] / + np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis], + k, metric) + self.searcher = searcher.score_brute_force().build() + print('Finish training searcher') + + if searcher_savedir is not None: + print(f'Save trained searcher under "{searcher_savedir}"') + os.makedirs(searcher_savedir, exist_ok=True) + self.searcher.serialize(searcher_savedir) + + def load_single_file(self, saved_embeddings): + compressed = np.load(saved_embeddings) + self.database = {key: compressed[key] for key in compressed.files} + print('Finished loading of clip embeddings.') + + def load_multi_files(self, data_archive): + out_data = {key: [] for key in self.database} + for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for key in d.files: + out_data[key].append(d[key]) + + return out_data + + def load_database(self): + + print(f'Load saved patch embedding from "{self.database_path}"') + file_content = glob.glob(os.path.join(self.database_path, '*.npz')) + + if len(file_content) == 1: + self.load_single_file(file_content[0]) + elif len(file_content) > 1: + data = [np.load(f) for f in file_content] + prefetched_data = parallel_data_prefetch(self.load_multi_files, data, + n_proc=min(len(data), cpu_count()), target_data_type='dict') + + self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in + self.database} + else: + raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?') + + print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.') + + def load_retriever(self, version='ViT-L/14', ): + model = FrozenClipImageEmbedder(model=version) + if torch.cuda.is_available(): + model.cuda() + model.eval() + return model + + def load_searcher(self): + print(f'load searcher for database {self.database_name} from {self.searcher_savedir}') + self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir) + print('Finished loading searcher.') + + def search(self, x, k): + if self.searcher is None and self.database['embedding'].shape[0] < 2e4: + self.train_searcher(k) # quickly fit searcher on the fly for small databases + assert self.searcher is not None, 'Cannot search with uninitialized searcher' + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + if len(x.shape) == 3: + x = x[:, 0] + query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis] + + start = time.time() + nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k) + end = time.time() + + out_embeddings = self.database['embedding'][nns] + out_img_ids = self.database['img_id'][nns] + out_pc = self.database['patch_coords'][nns] + + out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], + 'img_ids': out_img_ids, + 'patch_coords': out_pc, + 'queries': x, + 'exec_time': end - start, + 'nns': nns, + 'q_embeddings': query_embeddings} + + return out + + def __call__(self, x, n): + return self.search(x, n) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc) + # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt? + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + + parser.add_argument( + "--n_repeat", + type=int, + default=1, + help="number of repeats in CLIP latent space", + ) + + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", + ) + + parser.add_argument( + "--H", + type=int, + default=768, + help="image height, in pixel space", + ) + + parser.add_argument( + "--W", + type=int, + default=768, + help="image width, in pixel space", + ) + + parser.add_argument( + "--n_samples", + type=int, + default=3, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + + parser.add_argument( + "--scale", + type=float, + default=5.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + + parser.add_argument( + "--config", + type=str, + default="configs/retrieval-augmented-diffusion/768x768.yaml", + help="path to config which constructs model", + ) + + parser.add_argument( + "--ckpt", + type=str, + default="models/rdm/rdm768x768/model.ckpt", + help="path to checkpoint of model", + ) + + parser.add_argument( + "--clip_type", + type=str, + default="ViT-L/14", + help="which CLIP model to use for retrieval and NN encoding", + ) + parser.add_argument( + "--database", + type=str, + default='artbench-surrealism', + choices=DATABASES, + help="The database used for the search, only applied when --use_neighbors=True", + ) + parser.add_argument( + "--use_neighbors", + default=False, + action='store_true', + help="Include neighbors in addition to text prompt for conditioning", + ) + parser.add_argument( + "--knn", + default=10, + type=int, + help="The number of included neighbors, only applied when --use_neighbors=True", + ) + + opt = parser.parse_args() + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device) + + if opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + print(f"sampling scale for cfg is {opt.scale:.2f}") + + searcher = None + if opt.use_neighbors: + searcher = Searcher(opt.database) + + with torch.no_grad(): + with model.ema_scope(): + for n in trange(opt.n_iter, desc="Sampling"): + all_samples = list() + for prompts in tqdm(data, desc="data"): + print("sampling prompts:", prompts) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = clip_text_encoder.encode(prompts) + uc = None + if searcher is not None: + nn_dict = searcher(c, opt.knn) + c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1) + if opt.scale != 1.0: + uc = torch.zeros_like(c) + if isinstance(prompts, tuple): + prompts = list(prompts) + shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + ) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples_ddim: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + Image.fromarray(x_sample.astype(np.uint8)).save( + os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + all_samples.append(x_samples_ddim) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") diff --git a/examples/tutorial/stable_diffusion/scripts/sample_diffusion.py b/examples/tutorial/stable_diffusion/scripts/sample_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..876fe3c3642fcc8c7209e4f763c0134166615f78 --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/sample_diffusion.py @@ -0,0 +1,313 @@ +import argparse, os, sys, glob, datetime, yaml +import torch +import time +import numpy as np +from tqdm import trange + +from omegaconf import OmegaConf +from PIL import Image + +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config + +rescale = lambda x: (x + 1.) / 2. + +def custom_to_pil(x): + x = x.detach().cpu() + x = torch.clamp(x, -1., 1.) + x = (x + 1.) / 2. + x = x.permute(1, 2, 0).numpy() + x = (255 * x).astype(np.uint8) + x = Image.fromarray(x) + if not x.mode == "RGB": + x = x.convert("RGB") + return x + + +def custom_to_np(x): + # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py + sample = x.detach().cpu() + sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8) + sample = sample.permute(0, 2, 3, 1) + sample = sample.contiguous() + return sample + + +def logs2pil(logs, keys=["sample"]): + imgs = dict() + for k in logs: + try: + if len(logs[k].shape) == 4: + img = custom_to_pil(logs[k][0, ...]) + elif len(logs[k].shape) == 3: + img = custom_to_pil(logs[k]) + else: + print(f"Unknown format for key {k}. ") + img = None + except: + img = None + imgs[k] = img + return imgs + + +@torch.no_grad() +def convsample(model, shape, return_intermediates=True, + verbose=True, + make_prog_row=False): + + + if not make_prog_row: + return model.p_sample_loop(None, shape, + return_intermediates=return_intermediates, verbose=verbose) + else: + return model.progressive_denoising( + None, shape, verbose=True + ) + + +@torch.no_grad() +def convsample_ddim(model, steps, shape, eta=1.0 + ): + ddim = DDIMSampler(model) + bs = shape[0] + shape = shape[1:] + samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,) + return samples, intermediates + + +@torch.no_grad() +def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,): + + + log = dict() + + shape = [batch_size, + model.model.diffusion_model.in_channels, + model.model.diffusion_model.image_size, + model.model.diffusion_model.image_size] + + with model.ema_scope("Plotting"): + t0 = time.time() + if vanilla: + sample, progrow = convsample(model, shape, + make_prog_row=True) + else: + sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, + eta=eta) + + t1 = time.time() + + x_sample = model.decode_first_stage(sample) + + log["sample"] = x_sample + log["time"] = t1 - t0 + log['throughput'] = sample.shape[0] / (t1 - t0) + print(f'Throughput for this batch: {log["throughput"]}') + return log + +def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None): + if vanilla: + print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.') + else: + print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}') + + + tstart = time.time() + n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1 + # path = logdir + if model.cond_stage_model is None: + all_images = [] + + print(f"Running unconditional sampling for {n_samples} samples") + for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"): + logs = make_convolutional_sample(model, batch_size=batch_size, + vanilla=vanilla, custom_steps=custom_steps, + eta=eta) + n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample") + all_images.extend([custom_to_np(logs["sample"])]) + if n_saved >= n_samples: + print(f'Finish after generating {n_saved} samples') + break + all_img = np.concatenate(all_images, axis=0) + all_img = all_img[:n_samples] + shape_str = "x".join([str(x) for x in all_img.shape]) + nppath = os.path.join(nplog, f"{shape_str}-samples.npz") + np.savez(nppath, all_img) + + else: + raise NotImplementedError('Currently only sampling for unconditional models supported.') + + print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.") + + +def save_logs(logs, path, n_saved=0, key="sample", np_path=None): + for k in logs: + if k == key: + batch = logs[key] + if np_path is None: + for x in batch: + img = custom_to_pil(x) + imgpath = os.path.join(path, f"{key}_{n_saved:06}.png") + img.save(imgpath) + n_saved += 1 + else: + npbatch = custom_to_np(batch) + shape_str = "x".join([str(x) for x in npbatch.shape]) + nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz") + np.savez(nppath, npbatch) + n_saved += npbatch.shape[0] + return n_saved + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-r", + "--resume", + type=str, + nargs="?", + help="load from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-n", + "--n_samples", + type=int, + nargs="?", + help="number of samples to draw", + default=50000 + ) + parser.add_argument( + "-e", + "--eta", + type=float, + nargs="?", + help="eta for ddim sampling (0.0 yields deterministic sampling)", + default=1.0 + ) + parser.add_argument( + "-v", + "--vanilla_sample", + default=False, + action='store_true', + help="vanilla sampling (default option is DDIM sampling)?", + ) + parser.add_argument( + "-l", + "--logdir", + type=str, + nargs="?", + help="extra logdir", + default="none" + ) + parser.add_argument( + "-c", + "--custom_steps", + type=int, + nargs="?", + help="number of steps for ddim and fastdpm sampling", + default=50 + ) + parser.add_argument( + "--batch_size", + type=int, + nargs="?", + help="the bs", + default=10 + ) + return parser + + +def load_model_from_config(config, sd): + model = instantiate_from_config(config) + model.load_state_dict(sd,strict=False) + model.cuda() + model.eval() + return model + + +def load_model(config, ckpt, gpu, eval_mode): + if ckpt: + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + global_step = pl_sd["global_step"] + else: + pl_sd = {"state_dict": None} + global_step = None + model = load_model_from_config(config.model, + pl_sd["state_dict"]) + + return model, global_step + + +if __name__ == "__main__": + now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + sys.path.append(os.getcwd()) + command = " ".join(sys.argv) + + parser = get_parser() + opt, unknown = parser.parse_known_args() + ckpt = None + + if not os.path.exists(opt.resume): + raise ValueError("Cannot find {}".format(opt.resume)) + if os.path.isfile(opt.resume): + # paths = opt.resume.split("/") + try: + logdir = '/'.join(opt.resume.split('/')[:-1]) + # idx = len(paths)-paths[::-1].index("logs")+1 + print(f'Logdir is {logdir}') + except ValueError: + paths = opt.resume.split("/") + idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt + logdir = "/".join(paths[:idx]) + ckpt = opt.resume + else: + assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory" + logdir = opt.resume.rstrip("/") + ckpt = os.path.join(logdir, "model.ckpt") + + base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml"))) + opt.base = base_configs + + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + + gpu = True + eval_mode = True + + if opt.logdir != "none": + locallog = logdir.split(os.sep)[-1] + if locallog == "": locallog = logdir.split(os.sep)[-2] + print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'") + logdir = os.path.join(opt.logdir, locallog) + + print(config) + + model, global_step = load_model(config, ckpt, gpu, eval_mode) + print(f"global step: {global_step}") + print(75 * "=") + print("logging to:") + logdir = os.path.join(logdir, "samples", f"{global_step:08}", now) + imglogdir = os.path.join(logdir, "img") + numpylogdir = os.path.join(logdir, "numpy") + + os.makedirs(imglogdir) + os.makedirs(numpylogdir) + print(logdir) + print(75 * "=") + + # write config out + sampling_file = os.path.join(logdir, "sampling_config.yaml") + sampling_conf = vars(opt) + + with open(sampling_file, 'w') as f: + yaml.dump(sampling_conf, f, default_flow_style=False) + print(sampling_conf) + + + run(model, imglogdir, eta=opt.eta, + vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps, + batch_size=opt.batch_size, nplog=numpylogdir) + + print("done.") diff --git a/examples/tutorial/stable_diffusion/scripts/tests/test_checkpoint.py b/examples/tutorial/stable_diffusion/scripts/tests/test_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..a32e66d44cf2479d4dcc05d469cf7b4210d2c67d --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/tests/test_checkpoint.py @@ -0,0 +1,37 @@ +import os +import sys +from copy import deepcopy + +import yaml +from datetime import datetime + +from diffusers import StableDiffusionPipeline +import torch +from ldm.util import instantiate_from_config +from main import get_parser + +if __name__ == "__main__": + with torch.no_grad(): + yaml_path = "../../train_colossalai.yaml" + with open(yaml_path, 'r', encoding='utf-8') as f: + config = f.read() + base_config = yaml.load(config, Loader=yaml.FullLoader) + unet_config = base_config['model']['params']['unet_config'] + diffusion_model = instantiate_from_config(unet_config).to("cuda:0") + + pipe = StableDiffusionPipeline.from_pretrained( + "/data/scratch/diffuser/stable-diffusion-v1-4" + ).to("cuda:0") + dif_model_2 = pipe.unet + + random_input_ = torch.rand((4, 4, 32, 32)).to("cuda:0") + random_input_2 = torch.clone(random_input_).to("cuda:0") + time_stamp = torch.randint(20, (4,)).to("cuda:0") + time_stamp2 = torch.clone(time_stamp).to("cuda:0") + context_ = torch.rand((4, 77, 768)).to("cuda:0") + context_2 = torch.clone(context_).to("cuda:0") + + out_1 = diffusion_model(random_input_, time_stamp, context_) + out_2 = dif_model_2(random_input_2, time_stamp2, context_2) + print(out_1.shape) + print(out_2['sample'].shape) \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/scripts/tests/test_watermark.py b/examples/tutorial/stable_diffusion/scripts/tests/test_watermark.py new file mode 100644 index 0000000000000000000000000000000000000000..f93f8a6e70763c0e284157bc8225827520b2f5ef --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/tests/test_watermark.py @@ -0,0 +1,18 @@ +import cv2 +import fire +from imwatermark import WatermarkDecoder + + +def testit(img_path): + bgr = cv2.imread(img_path) + decoder = WatermarkDecoder('bytes', 136) + watermark = decoder.decode(bgr, 'dwtDct') + try: + dec = watermark.decode('utf-8') + except: + dec = "null" + print(dec) + + +if __name__ == "__main__": + fire.Fire(testit) \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/scripts/train_searcher.py b/examples/tutorial/stable_diffusion/scripts/train_searcher.py new file mode 100644 index 0000000000000000000000000000000000000000..1e7904889c0145f9fb740fd4ae8e45c08728b255 --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/train_searcher.py @@ -0,0 +1,147 @@ +import os, sys +import numpy as np +import scann +import argparse +import glob +from multiprocessing import cpu_count +from tqdm import tqdm + +from ldm.util import parallel_data_prefetch + + +def search_bruteforce(searcher): + return searcher.score_brute_force().build() + + +def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, + partioning_trainsize, num_leaves, num_leaves_to_search): + return searcher.tree(num_leaves=num_leaves, + num_leaves_to_search=num_leaves_to_search, + training_sample_size=partioning_trainsize). \ + score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() + + +def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): + return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( + reorder_k).build() + +def load_datapool(dpath): + + + def load_single_file(saved_embeddings): + compressed = np.load(saved_embeddings) + database = {key: compressed[key] for key in compressed.files} + return database + + def load_multi_files(data_archive): + database = {key: [] for key in data_archive[0].files} + for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for key in d.files: + database[key].append(d[key]) + + return database + + print(f'Load saved patch embedding from "{dpath}"') + file_content = glob.glob(os.path.join(dpath, '*.npz')) + + if len(file_content) == 1: + data_pool = load_single_file(file_content[0]) + elif len(file_content) > 1: + data = [np.load(f) for f in file_content] + prefetched_data = parallel_data_prefetch(load_multi_files, data, + n_proc=min(len(data), cpu_count()), target_data_type='dict') + + data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} + else: + raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') + + print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.') + return data_pool + + +def train_searcher(opt, + metric='dot_product', + partioning_trainsize=None, + reorder_k=None, + # todo tune + aiq_thld=0.2, + dims_per_block=2, + num_leaves=None, + num_leaves_to_search=None,): + + data_pool = load_datapool(opt.database) + k = opt.knn + + if not reorder_k: + reorder_k = 2 * k + + # normalize + # embeddings = + searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) + pool_size = data_pool['embedding'].shape[0] + + print(*(['#'] * 100)) + print('Initializing scaNN searcher with the following values:') + print(f'k: {k}') + print(f'metric: {metric}') + print(f'reorder_k: {reorder_k}') + print(f'anisotropic_quantization_threshold: {aiq_thld}') + print(f'dims_per_block: {dims_per_block}') + print(*(['#'] * 100)) + print('Start training searcher....') + print(f'N samples in pool is {pool_size}') + + # this reflects the recommended design choices proposed at + # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md + if pool_size < 2e4: + print('Using brute force search.') + searcher = search_bruteforce(searcher) + elif 2e4 <= pool_size and pool_size < 1e5: + print('Using asymmetric hashing search and reordering.') + searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) + else: + print('Using using partioning, asymmetric hashing search and reordering.') + + if not partioning_trainsize: + partioning_trainsize = data_pool['embedding'].shape[0] // 10 + if not num_leaves: + num_leaves = int(np.sqrt(pool_size)) + + if not num_leaves_to_search: + num_leaves_to_search = max(num_leaves // 20, 1) + + print('Partitioning params:') + print(f'num_leaves: {num_leaves}') + print(f'num_leaves_to_search: {num_leaves_to_search}') + # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) + searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, + partioning_trainsize, num_leaves, num_leaves_to_search) + + print('Finish training searcher') + searcher_savedir = opt.target_path + os.makedirs(searcher_savedir, exist_ok=True) + searcher.serialize(searcher_savedir) + print(f'Saved trained searcher under "{searcher_savedir}"') + +if __name__ == '__main__': + sys.path.append(os.getcwd()) + parser = argparse.ArgumentParser() + parser.add_argument('--database', + '-d', + default='data/rdm/retrieval_databases/openimages', + type=str, + help='path to folder containing the clip feature of the database') + parser.add_argument('--target_path', + '-t', + default='data/rdm/searchers/openimages', + type=str, + help='path to the target folder where the searcher shall be stored.') + parser.add_argument('--knn', + '-k', + default=20, + type=int, + help='number of nearest neighbors, for which the searcher shall be optimized') + + opt, _ = parser.parse_known_args() + + train_searcher(opt,) \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/scripts/txt2img.py b/examples/tutorial/stable_diffusion/scripts/txt2img.py new file mode 100644 index 0000000000000000000000000000000000000000..59c16a1db87123d7d02ff45f3012a196f2a1a6e0 --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/txt2img.py @@ -0,0 +1,344 @@ +import argparse, os, sys, glob +import cv2 +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from imwatermark import WatermarkEncoder +from itertools import islice +from einops import rearrange +from torchvision.utils import make_grid +import time +from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import contextmanager, nullcontext + +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler + +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from transformers import AutoFeatureExtractor + + +# load safety model +safety_model_id = "CompVis/stable-diffusion-safety-checker" +safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) +safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +def put_watermark(img, wm_encoder=None): + if wm_encoder is not None: + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + img = wm_encoder.encode(img, 'dwtDct') + img = Image.fromarray(img[:, :, ::-1]) + return img + + +def load_replacement(x): + try: + hwc = x.shape + y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0])) + y = (np.array(y)/255.0).astype(x.dtype) + assert y.shape == x.shape + return y + except Exception: + return x + + +def check_safety(x_image): + safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") + x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + assert x_checked_image.shape[0] == len(has_nsfw_concept) + for i in range(len(has_nsfw_concept)): + if has_nsfw_concept[i]: + x_checked_image[i] = load_replacement(x_checked_image[i]) + return x_checked_image, has_nsfw_concept + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + parser.add_argument( + "--skip_save", + action='store_true', + help="do not save individual samples. For speed measurements.", + ) + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--laion400m", + action='store_true', + help="uses the LAION400M model", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across samples ", + ) + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=2, + help="sample this often", + ) + parser.add_argument( + "--H", + type=int, + default=512, + help="image height, in pixel space", + ) + parser.add_argument( + "--W", + type=int, + default=512, + help="image width, in pixel space", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor", + ) + parser.add_argument( + "--n_samples", + type=int, + default=3, + help="how many samples to produce for each given prompt. A.k.a. batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v1-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + default="models/ldm/stable-diffusion-v1/model.ckpt", + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + opt = parser.parse_args() + + if opt.laion400m: + print("Falling back to LAION 400M model...") + opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" + opt.ckpt = "models/ldm/text2img-large/model.ckpt" + opt.outdir = "outputs/txt2img-samples-laion400m" + + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + if opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + wm = "StableDiffusionV1" + wm_encoder = WatermarkEncoder() + wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + start_code = None + if opt.fixed_code: + start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) + + precision_scope = autocast if opt.precision=="autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + tic = time.time() + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + + x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim) + + x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) + + if not opt.skip_save: + for x_sample in x_checked_image_torch: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + + if not opt.skip_grid: + all_samples.append(x_checked_image_torch) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + img = Image.fromarray(grid.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + toc = time.time() + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" + f" \nEnjoy.") + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/stable_diffusion/setup.py b/examples/tutorial/stable_diffusion/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..a24d541676407eee1bea271179ffd1d80c6a8e79 --- /dev/null +++ b/examples/tutorial/stable_diffusion/setup.py @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages + +setup( + name='latent-diffusion', + version='0.0.1', + description='', + packages=find_packages(), + install_requires=[ + 'torch', + 'numpy', + 'tqdm', + ], +) \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/train.sh b/examples/tutorial/stable_diffusion/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..63abcadbf62ba3c3d6773ec242852f264bddacce --- /dev/null +++ b/examples/tutorial/stable_diffusion/train.sh @@ -0,0 +1,4 @@ +HF_DATASETS_OFFLINE=1 +TRANSFORMERS_OFFLINE=1 + +python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..ac31ace4bfae025025b1098719aba873db615d1c --- /dev/null +++ b/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +markers = + cpu: tests which can run on CPU + gpu: tests which requires a single GPU + dist: tests which are run in a multi-GPU or multi-machine environment + experiment: tests for experimental features \ No newline at end of file diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt new file mode 100644 index 0000000000000000000000000000000000000000..f9e8960d2eaf363cd952f7d7b69818c9f58bb528 --- /dev/null +++ b/requirements/requirements-test.txt @@ -0,0 +1,12 @@ +fbgemm-gpu==0.2.0 +pytest +torchvision +transformers +timm +titans +torchaudio +torchrec==0.2.0 +contexttimer +einops +triton==2.0.0.dev20221011 +git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn diff --git a/requirements/requirements.txt b/requirements/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..5ac4a3c606e28d1f0bce7c709c7160cfff4aa01f --- /dev/null +++ b/requirements/requirements.txt @@ -0,0 +1,9 @@ +numpy +tqdm +psutil +packaging +pre-commit +rich +click +fabric +contexttimer diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..6a2c274edefe2f6386ba1b1c2bbccdd326533365 --- /dev/null +++ b/setup.py @@ -0,0 +1,371 @@ +import os +import re +import subprocess + +from setuptools import Extension, find_packages, setup +import torch +from typing import Optional, Union +from pathlib import Path + +if torch.__version__ >= '1.5': + from torch.utils.cpp_extension import ROCM_HOME + if ((torch.version.hip is not None) and (ROCM_HOME is not None)): + CUDA_HOME = ROCM_HOME + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) +build_cuda_ext = False +build_hip_ext = True +ext_modules = [] + +if int(os.environ.get('NO_CUDA_EXT', '0')) == 1: + build_cuda_ext = False + +if int(os.environ.get('NO_HIP_EXT', '0')) == 1: + build_hip_ext = False + + +def get_cuda_bare_metal_version(cuda_dir): + if build_cuda_ext == True: + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + else: + raw_output = subprocess.check_output([cuda_dir + "/bin/hipcc", "--version"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("version:") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + + +def check_cuda_torch_binary_vs_bare_metal(cuda_dir): + raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) + if build_cuda_ext == True: + torch_binary_major = torch.version.cuda.split(".")[0] + torch_binary_minor = torch.version.cuda.split(".")[1] + else: + torch_binary_major = torch.version.hip.split(".")[0] + torch_binary_minor = torch.version.hip.split(".")[1] + + print("\nCompiling cuda extensions with") + print(raw_output + "from " + cuda_dir + "/bin\n") + + if bare_metal_major != torch_binary_major: + print(f'The detected CUDA version ({raw_output}) mismatches the version that was used to compile PyTorch ' + f'({torch.version.cuda}). CUDA extension will not be installed.') + return False + + if bare_metal_minor != torch_binary_minor: + print("\nWarning: Cuda extensions are being compiled with a version of Cuda that does " + "not match the version used to compile Pytorch binaries. " + f"Pytorch binaries were compiled with Cuda {torch.version.cuda}.\n" + "In some cases, a minor-version mismatch will not cause later errors: " + "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. ") + return True + + +def check_cuda_availability(cuda_dir): + if not torch.cuda.is_available(): + # https://github.com/NVIDIA/apex/issues/486 + # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query + # torch.cuda.get_device_capability(), which will fail if you are compiling in an environment + # without visible GPUs (e.g. during an nvidia-docker build command). + print( + '\nWarning: Torch did not find available GPUs on this system.\n', + 'If your intention is to cross-compile, this is not an error.\n' + 'By default, Colossal-AI will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n' + 'Volta (compute capability 7.0), Turing (compute capability 7.5),\n' + 'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n' + 'If you wish to cross-compile for a single specific architecture,\n' + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n') + if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: + _, bare_metal_major, _ = get_cuda_bare_metal_version(cuda_dir) + if int(bare_metal_major) == 11: + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" + else: + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" + return False + + if cuda_dir is None: + print("nvcc was not found. CUDA extension will not be installed. If you're installing within a container from " + "https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") + return False + return True + + +def append_nvcc_threads(nvcc_extra_args): + _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) + if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + + +def fetch_requirements(path): + with open(path, 'r') as fd: + return [r.strip() for r in fd.readlines()] + + +def fetch_readme(): + with open('README.md', encoding='utf-8') as f: + return f.read() + +def get_sha(root: Union[str, Path]) -> str: + try: + return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=root).decode('ascii').strip() + except Exception: + return 'Unknown' + +def get_version_add(sha: Optional[str] = None) -> str: + cai_root = os.path.dirname(os.path.abspath(__file__)) + add_version_path = os.path.join(os.path.join(cai_root, "colossalai"), "version.py") + if sha != 'Unknown': + if sha is None: + sha = get_sha(cai_root) + version = 'git' + sha[:7] + + if os.getenv('COLOSSALAI_BUILD_VERSION'): + version_dtk = os.getenv('COLOSSALAI_BUILD_VERSION', "") + version += "." + version_dtk + + with open(add_version_path, encoding="utf-8",mode="a") as file: + file.write("__version__=__version__+'+{}'\n".format(version)) + file.close() + + +def get_version(): + setup_file_path = os.path.abspath(__file__) + project_path = os.path.dirname(setup_file_path) + version_txt_path = os.path.join(project_path, 'version.txt') + version_py_path = os.path.join(project_path, 'colossalai/version.py') + + with open(version_txt_path) as f: + version = f.read().strip() + if build_cuda_ext: + torch_version = '.'.join(torch.__version__.split('.')[:2]) + cuda_version = '.'.join(get_cuda_bare_metal_version(CUDA_HOME)[1:]) + version += f'+torch{torch_version}cu{cuda_version}' + + # write version into version.py + with open(version_py_path, 'w') as f: + f.write(f"__version__ = '{version}'\n") + + if build_hip_ext: + get_version_add() + with open(version_py_path, encoding='utf-8') as f: + exec(compile(f.read(), version_py_path, 'exec')) + return locals()['__version__'] + + return version + + +if build_cuda_ext or build_hip_ext: + build_cuda_ext = check_cuda_availability(CUDA_HOME) and check_cuda_torch_binary_vs_bare_metal(CUDA_HOME) + +try: + import torch + from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + + if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 10): + raise RuntimeError("Colossal-AI requires Pytorch 1.10 or newer.\n" + "The latest stable release can be obtained from https://pytorch.org/") +except ImportError: + raise ModuleNotFoundError('torch is not found. You need to install PyTorch before installing Colossal-AI.') + + +if build_hip_ext: + # Set up macros for forward/backward compatibility hack around + # https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e + # and + # https://github.com/NVIDIA/apex/issues/456 + # https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac + version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + if build_hip_ext: + hip_macros = ['-DCOLOSSAL_HIP'] + + def cuda_ext_helper(name, sources, extra_cuda_flags): + return CUDAExtension( + name=name, + sources=[os.path.join('colossalai/kernel/hip_native/csrc', path) for path in sources], + include_dirs=[os.path.join( + this_dir, 'colossalai/kernel/hip_native/csrc/kernels/include')] + [os.path.join(this_dir, 'colossalai/kernel/hip_native/csrc')] + ['/opt/dtk/hiprand/include'] + ['/opt/dtk/rocrand/include'], + extra_compile_args={'cxx': ['-O3'] + version_dependent_macros + hip_macros, + 'nvcc': ['-O3'] + version_dependent_macros + hip_macros + extra_cuda_flags}) + + from torch.utils.hipify import hipify_python + hipify_python.hipify( + project_directory=this_dir, + output_directory=this_dir, + includes="colossalai/kernel/cuda_native/*", + show_detailed=True, + is_pytorch_extension=True, + ) + + cc_flag = [] + + extra_cuda_flags = ['-lineinfo'] + + ext_modules.append( + cuda_ext_helper('colossalai._C.fused_optim', [ + 'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.hip', 'multi_tensor_scale_kernel.hip', + 'multi_tensor_adam.hip', 'multi_tensor_l2norm_kernel.hip', 'multi_tensor_lamb.hip'], + extra_cuda_flags + cc_flag )) + + + extra_cuda_flags = [ + '-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_CONVERSIONS__' + ] + + ext_modules.append( + cuda_ext_helper('colossalai._C.scaled_upper_triang_masked_softmax', + ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_hip.hip'], + extra_cuda_flags + cc_flag)) + + ext_modules.append( + cuda_ext_helper('colossalai._C.scaled_masked_softmax', + ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_hip.hip'], extra_cuda_flags + cc_flag)) + + ext_modules.append( + cuda_ext_helper('colossalai._C.moe', ['moe_hip.cpp', 'moe_hip_kernel.hip'], extra_cuda_flags + cc_flag)) + + extra_cuda_flags = [] + + ext_modules.append( + cuda_ext_helper('colossalai._C.layer_norm', ['layer_norm_hip.cpp', 'layer_norm_hip_kernel.hip'], + extra_cuda_flags + cc_flag)) + + extra_cuda_flags = [ + '-std=c++14', '-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_CONVERSIONS__', '-U__HIP_NO_HALF2_OPERATORS__', + '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + ] + + ext_modules.append( + cuda_ext_helper('colossalai._C.multihead_attention', [ + 'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.hip', 'kernels/transform_kernels.hip', + 'kernels/dropout_kernels.hip', 'kernels/normalize_kernels.hip', 'kernels/softmax_kernels.hip', + 'kernels/general_kernels.hip', 'kernels/hip_util.hip' + ], extra_cuda_flags + cc_flag)) + + extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native'] + ext_modules.append(cuda_ext_helper('colossalai._C.cpu_optim', ['cpu_adam.cpp'], extra_cuda_flags + extra_cxx_flags)) + + +build_cuda_ext = False +if build_cuda_ext: + # Set up macros for forward/backward compatibility hack around + # https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e + # and + # https://github.com/NVIDIA/apex/issues/456 + # https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac + version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + + def cuda_ext_helper(name, sources, extra_cuda_flags, extra_cxx_flags=[]): + return CUDAExtension( + name=name, + sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in sources], + include_dirs=[os.path.join(this_dir, 'colossalai/kernel/cuda_native/csrc/kernels/include')], + extra_compile_args={ + 'cxx': ['-O3'] + version_dependent_macros + extra_cxx_flags, + 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags) + }) + + cc_flag = [] + for arch in torch.cuda.get_arch_list(): + res = re.search(r'sm_(\d+)', arch) + if res: + arch_cap = res[1] + if int(arch_cap) >= 60: + cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) + + extra_cuda_flags = ['-lineinfo'] + + ext_modules.append( + cuda_ext_helper('colossalai._C.fused_optim', [ + 'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu', + 'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu' + ], extra_cuda_flags + cc_flag)) + + extra_cuda_flags = [ + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', + '--expt-extended-lambda' + ] + + ext_modules.append( + cuda_ext_helper('colossalai._C.scaled_upper_triang_masked_softmax', + ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'], + extra_cuda_flags + cc_flag)) + + ext_modules.append( + cuda_ext_helper('colossalai._C.scaled_masked_softmax', + ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'], extra_cuda_flags + cc_flag)) + + ext_modules.append( + cuda_ext_helper('colossalai._C.moe', ['moe_cuda.cpp', 'moe_cuda_kernel.cu'], extra_cuda_flags + cc_flag)) + + extra_cuda_flags = ['-maxrregcount=50'] + + ext_modules.append( + cuda_ext_helper('colossalai._C.layer_norm', ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'], + extra_cuda_flags + cc_flag)) + + extra_cuda_flags = [ + '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + ] + + ext_modules.append( + cuda_ext_helper('colossalai._C.multihead_attention', [ + 'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu', + 'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu', + 'kernels/general_kernels.cu', 'kernels/cuda_util.cu' + ], extra_cuda_flags + cc_flag)) + + extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native'] + ext_modules.append(cuda_ext_helper('colossalai._C.cpu_optim', ['cpu_adam.cpp'], extra_cuda_flags, extra_cxx_flags)) + +setup(name='colossalai', + version=get_version(), + packages=find_packages(exclude=( + 'benchmark', + 'docker', + 'tests', + 'docs', + 'examples', + 'tests', + 'scripts', + 'requirements', + '*.egg-info', + )), + description='An integrated large-scale model training system with efficient parallelization techniques', + long_description=fetch_readme(), + long_description_content_type='text/markdown', + license='Apache Software License 2.0', + url='https://www.colossalai.org', + project_urls={ + 'Forum': 'https://github.com/hpcaitech/ColossalAI/discussions', + 'Bug Tracker': 'https://github.com/hpcaitech/ColossalAI/issues', + 'Examples': 'https://github.com/hpcaitech/ColossalAI-Examples', + 'Documentation': 'http://colossalai.readthedocs.io', + 'Github': 'https://github.com/hpcaitech/ColossalAI', + }, + ext_modules=ext_modules, + cmdclass={'build_ext': BuildExtension} if ext_modules else {}, + install_requires=fetch_requirements('requirements/requirements.txt'), + entry_points=''' + [console_scripts] + colossalai=colossalai.cli:cli + ''', + python_requires='>=3.6', + classifiers=[ + 'Programming Language :: Python :: 3', + 'License :: OSI Approved :: Apache Software License', + 'Environment :: GPU :: NVIDIA CUDA', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: System :: Distributed Computing', + ], + package_data={'colossalai': ['_C/*.pyi']}) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/components_to_test/__init__.py b/tests/components_to_test/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e498786fb1399ed034eec31c398d861f3fb90226 --- /dev/null +++ b/tests/components_to_test/__init__.py @@ -0,0 +1,18 @@ +from . import ( + bert, + gpt2, + hanging_param_model, + inline_op_model, + nested_model, + repeated_computed_layers, + resnet, + simple_net, +) +from .utils import run_fwd_bwd + +from . import albert # isort:skip + +__all__ = [ + 'bert', 'gpt2', 'hanging_param_model', 'inline_op_model', 'nested_model', 'repeated_computed_layers', 'resnet', + 'simple_net', 'run_fwd_bwd', 'albert' +] diff --git a/tests/components_to_test/albert.py b/tests/components_to_test/albert.py new file mode 100644 index 0000000000000000000000000000000000000000..d5b6bc89a83e0ae49a4045800b6e9d57ff848604 --- /dev/null +++ b/tests/components_to_test/albert.py @@ -0,0 +1,59 @@ +import torch +import transformers +from packaging import version +from transformers import AlbertConfig, AlbertForSequenceClassification + +from .bert import get_bert_data_loader +from .registry import non_distributed_component_funcs + + +@non_distributed_component_funcs.register(name='albert') +def get_training_components(): + hidden_dim = 8 + num_head = 4 + sequence_length = 12 + num_layer = 2 + vocab_size = 32 + + def bert_model_builder(checkpoint: bool = False): + config = AlbertConfig(vocab_size=vocab_size, + gradient_checkpointing=checkpoint, + hidden_size=hidden_dim, + intermediate_size=hidden_dim * 4, + num_attention_heads=num_head, + max_position_embeddings=sequence_length, + num_hidden_layers=num_layer, + hidden_dropout_prob=0., + attention_probs_dropout_prob=0.) + print('building AlbertForSequenceClassification model') + + # adapting huggingface BertForSequenceClassification for single unitest calling interface + class ModelAaptor(AlbertForSequenceClassification): + + def forward(self, input_ids, labels): + """ + inputs: data, label + outputs: loss + """ + return super().forward(input_ids=input_ids, labels=labels)[0] + + model = ModelAaptor(config) + # if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"): + # model.gradient_checkpointing_enable() + + return model + + is_distrbuted = torch.distributed.is_initialized() + trainloader = get_bert_data_loader(n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distrbuted=is_distrbuted) + testloader = get_bert_data_loader(n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distrbuted=is_distrbuted) + + criterion = None + return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/bert.py b/tests/components_to_test/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..c1faa6f9d892650c4717d7baa359165822a8165b --- /dev/null +++ b/tests/components_to_test/bert.py @@ -0,0 +1,84 @@ +import torch +import transformers +from packaging import version +from torch.utils.data import SequentialSampler +from transformers import BertConfig, BertForSequenceClassification + +from .registry import non_distributed_component_funcs + + +def get_bert_data_loader( + n_class, + batch_size, + total_samples, + sequence_length, + device=torch.device('cpu:0'), + is_distrbuted=False, +): + train_data = torch.randint( + low=0, + high=n_class, + size=(total_samples, sequence_length), + device=device, + dtype=torch.long, + ) + train_label = torch.randint(low=0, high=2, size=(total_samples,), device=device, dtype=torch.long) + train_dataset = torch.utils.data.TensorDataset(train_data, train_label) + if is_distrbuted: + sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + sampler = SequentialSampler(train_dataset) + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler) + return train_loader + + +@non_distributed_component_funcs.register(name='bert') +def get_training_components(): + hidden_dim = 8 + num_head = 4 + sequence_length = 12 + num_layer = 2 + vocab_size = 32 + + def bert_model_builder(checkpoint: bool = False): + config = BertConfig(vocab_size=vocab_size, + gradient_checkpointing=checkpoint, + hidden_size=hidden_dim, + intermediate_size=hidden_dim * 4, + num_attention_heads=num_head, + max_position_embeddings=sequence_length, + num_hidden_layers=num_layer, + hidden_dropout_prob=0., + attention_probs_dropout_prob=0.) + print('building BertForSequenceClassification model') + + # adapting huggingface BertForSequenceClassification for single unitest calling interface + class ModelAaptor(BertForSequenceClassification): + + def forward(self, input_ids, labels): + """ + inputs: data, label + outputs: loss + """ + return super().forward(input_ids=input_ids, labels=labels)[0] + + model = ModelAaptor(config) + if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"): + model.gradient_checkpointing_enable() + + return model + + is_distrbuted = torch.distributed.is_initialized() + trainloader = get_bert_data_loader(n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distrbuted=is_distrbuted) + testloader = get_bert_data_loader(n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distrbuted=is_distrbuted) + + criterion = None + return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/gpt2.py b/tests/components_to_test/gpt2.py new file mode 100644 index 0000000000000000000000000000000000000000..fe25b4923fa27493be7f006956ce642be6047503 --- /dev/null +++ b/tests/components_to_test/gpt2.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +from transformers import GPT2Config, GPT2LMHeadModel + +from colossalai.utils.cuda import get_current_device + +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + + +class DummyDataLoader(DummyDataGenerator): + vocab_size = 128 + batch_size = 4 + seq_len = 64 + + def generate(self): + input_ids = torch.randint(0, + DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len), + device=get_current_device()) + return input_ids, input_ids + + +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50304, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids): + # Only return lm_logits + attention_mask = torch.ones_like(input_ids) + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +def gpt2_micro(checkpoint=True): + return GPTLMModel(checkpoint=checkpoint, + hidden_size=32, + num_layers=2, + num_attention_heads=4, + max_seq_len=64, + vocab_size=128) + + +def gpt2_s(checkpoint=True): + return GPTLMModel(checkpoint=checkpoint) + + +def gpt2_m(checkpoint=True): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +@non_distributed_component_funcs.register(name='gpt2') +def get_training_components(): + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = GPTLMLoss() + return gpt2_micro, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/hanging_param_model.py b/tests/components_to_test/hanging_param_model.py new file mode 100644 index 0000000000000000000000000000000000000000..329a08ea28f0224226c6061c2894982f3cd1d397 --- /dev/null +++ b/tests/components_to_test/hanging_param_model.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.nn import CheckpointModule + +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + + +class HangingParamModule(CheckpointModule): + """ + Hanging Parameter: a parameter dose not belong to a leaf Module. + It has subordinate nn.modules and a nn.Parameter. + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.proj1 = nn.Linear(4, 8) + self.weight = nn.Parameter(torch.randn(8, 8)) + self.proj2 = nn.Linear(8, 4) + + def forward(self, x): + x = self.proj1(x) + x = F.linear(x, self.weight) + x = self.proj2(x) + return x + + +class DummyDataLoader(DummyDataGenerator): + + def generate(self): + data = torch.rand(16, 4) + label = torch.randint(low=0, high=2, size=(16,)) + return data, label + + +@non_distributed_component_funcs.register(name='hanging_param_model') +def get_training_components(): + + def model_builder(checkpoint=False): + return HangingParamModule(checkpoint) + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/inline_op_model.py b/tests/components_to_test/inline_op_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f061d48f92c6eb7a2095fed039f2d2e49c1e07d3 --- /dev/null +++ b/tests/components_to_test/inline_op_model.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.nn import CheckpointModule + +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + + +class InlineOpModule(CheckpointModule): + """ + a module with inline Ops + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.proj1 = nn.Linear(4, 8) + self.proj2 = nn.Linear(8, 8) + + def forward(self, x): + + x = self.proj1(x) + # inline add_ + x.add_(10) + x = self.proj2(x) + # inline relu_ + x = torch.relu_(x) + x = self.proj2(x) + return x + + +class DummyDataLoader(DummyDataGenerator): + + def generate(self): + data = torch.rand(16, 4) + label = torch.randint(low=0, high=2, size=(16,)) + return data, label + + +@non_distributed_component_funcs.register(name='inline_op_model') +def get_training_components(): + + def model_builder(checkpoint=False): + return InlineOpModule(checkpoint) + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/nested_model.py b/tests/components_to_test/nested_model.py new file mode 100644 index 0000000000000000000000000000000000000000..339084639244ef4f434218ff77673b46b9deb5c1 --- /dev/null +++ b/tests/components_to_test/nested_model.py @@ -0,0 +1,55 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.nn import CheckpointModule + +from .registry import non_distributed_component_funcs +from .utils import DummyDataGenerator + + +class SubNet(nn.Module): + + def __init__(self, out_features) -> None: + super().__init__() + self.bias = nn.Parameter(torch.zeros(out_features)) + + def forward(self, x, weight): + return F.linear(x, weight, self.bias) + + +class NestedNet(CheckpointModule): + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint) + self.fc1 = nn.Linear(5, 5) + self.sub_fc = SubNet(5) + self.fc2 = nn.Linear(5, 2) + + def forward(self, x): + x = self.fc1(x) + x = self.sub_fc(x, self.fc1.weight) + x = self.fc1(x) + x = self.fc2(x) + return x + + +class DummyDataLoader(DummyDataGenerator): + + def generate(self): + data = torch.rand(16, 5) + label = torch.randint(low=0, high=2, size=(16,)) + return data, label + + +@non_distributed_component_funcs.register(name='nested_model') +def get_training_components(): + + def model_builder(checkpoint=False): + return NestedNet(checkpoint) + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + return model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/registry.py b/tests/components_to_test/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..728ed9eba6ea176b70f152751f877781a5beb214 --- /dev/null +++ b/tests/components_to_test/registry.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python + + +class Registry: + + def __init__(self): + self._registry = dict() + + def register(self, name): + assert name not in self._registry + + def _regsiter(callable_): + self._registry[name] = callable_ + + return _regsiter + + def get_callable(self, name: str): + return self._registry[name] + + def __iter__(self): + self._idx = 0 + self._len = len(self._registry) + self._names = list(self._registry.keys()) + return self + + def __next__(self): + if self._idx < self._len: + key = self._names[self._idx] + callable_ = self._registry[key] + self._idx += 1 + return callable_ + else: + raise StopIteration + + +non_distributed_component_funcs = Registry() +model_paralle_component_funcs = Registry() + +__all__ = ['non_distributed_component_funcs', 'model_paralle_component_funcs'] diff --git a/tests/components_to_test/repeated_computed_layers.py b/tests/components_to_test/repeated_computed_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f84bd0e203eb3fd2d23ae627048f422cfc8a51 --- /dev/null +++ b/tests/components_to_test/repeated_computed_layers.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python + +import torch +import torch.nn as nn + +from colossalai.nn import CheckpointModule + +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + + +class NetWithRepeatedlyComputedLayers(CheckpointModule): + """ + This model is to test with layers which go through forward pass multiple times. + In this model, the fc1 and fc2 call forward twice + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.fc1 = nn.Linear(5, 5) + self.fc2 = nn.Linear(5, 5) + self.fc3 = nn.Linear(5, 2) + self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3] + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +class DummyDataLoader(DummyDataGenerator): + + def generate(self): + data = torch.rand(16, 5) + label = torch.randint(low=0, high=2, size=(16,)) + return data, label + + +@non_distributed_component_funcs.register(name='repeated_computed_layers') +def get_training_components(): + + def model_builder(checkpoint=False): + return NetWithRepeatedlyComputedLayers(checkpoint) + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + return model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/resnet.py b/tests/components_to_test/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..193832ebc12da0c3f6d18a8bb21fa9a67a2ac02c --- /dev/null +++ b/tests/components_to_test/resnet.py @@ -0,0 +1,33 @@ +from torchvision.models import resnet18 +from .registry import non_distributed_component_funcs +from pathlib import Path +import os +import torch +from torchvision.transforms import transforms +from torchvision.datasets import CIFAR10 +from colossalai.utils import get_dataloader + + +def get_cifar10_dataloader(train): + # build dataloaders + dataset = CIFAR10(root=Path(os.environ['DATA']), + download=True, + train=train, + transform=transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])) + dataloader = get_dataloader(dataset=dataset, shuffle=True, batch_size=16, drop_last=True) + return dataloader + + +@non_distributed_component_funcs.register(name='resnet18') +def get_resnet_training_components(): + + def model_builder(checkpoint=False): + return resnet18(num_classes=10) + + trainloader = get_cifar10_dataloader(train=True) + testloader = get_cifar10_dataloader(train=False) + + criterion = torch.nn.CrossEntropyLoss() + return model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/simple_net.py b/tests/components_to_test/simple_net.py new file mode 100644 index 0000000000000000000000000000000000000000..cd9d7ebc0b1a1a699e427905269e8e8a9b0936bf --- /dev/null +++ b/tests/components_to_test/simple_net.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn + +from colossalai.nn import CheckpointModule +from colossalai.utils.cuda import get_current_device + +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + + +class SimpleNet(CheckpointModule): + """ + In this no-leaf module, it has subordinate nn.modules and a nn.Parameter. + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.embed = nn.Embedding(20, 4) + self.proj1 = nn.Linear(4, 8) + self.ln1 = nn.LayerNorm(8) + self.proj2 = nn.Linear(8, 4) + self.ln2 = nn.LayerNorm(4) + self.classifier = nn.Linear(4, 4) + + def forward(self, x): + x = self.embed(x) + x = self.proj1(x) + x = self.ln1(x) + x = self.proj2(x) + x = self.ln2(x) + x = self.classifier(x) + return x + + +class DummyDataLoader(DummyDataGenerator): + + def generate(self): + data = torch.randint(low=0, high=20, size=(16,), device=get_current_device()) + label = torch.randint(low=0, high=2, size=(16,), device=get_current_device()) + return data, label + + +@non_distributed_component_funcs.register(name='simple_net') +def get_training_components(): + + def model_builder(checkpoint=False): + return SimpleNet(checkpoint) + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/utils/__init__.py b/tests/components_to_test/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f223f7d322cb8e7d023139a944edafbbc161ba6e --- /dev/null +++ b/tests/components_to_test/utils/__init__.py @@ -0,0 +1,2 @@ +from .dummy_data_generator import DummyDataGenerator +from .executor import run_fwd_bwd diff --git a/tests/components_to_test/utils/dummy_data_generator.py b/tests/components_to_test/utils/dummy_data_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..5ab33e86de230da778d0e2bc1b8d1e8e581c1f79 --- /dev/null +++ b/tests/components_to_test/utils/dummy_data_generator.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod + + +class DummyDataGenerator(ABC): + + def __init__(self, length=10): + self.length = length + + @abstractmethod + def generate(self): + pass + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.generate() + else: + raise StopIteration + + def __len__(self): + return self.length diff --git a/tests/components_to_test/utils/executor.py b/tests/components_to_test/utils/executor.py new file mode 100644 index 0000000000000000000000000000000000000000..e77152561e6ce02c26e48152b38707dd596b393c --- /dev/null +++ b/tests/components_to_test/utils/executor.py @@ -0,0 +1,29 @@ +import torch + + +def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor: + """run_fwd_bwd + run fwd and bwd for the model + + Args: + model (torch.nn.Module): a PyTorch model + data (torch.Tensor): input data + label (torch.Tensor): label + criterion (Optional[Callable]): a function of criterion + + Returns: + torch.Tensor: loss of fwd + """ + if criterion: + y = model(data) + y = y.float() + loss = criterion(y, label) + else: + loss = model(data, label) + + loss = loss.float() + if optimizer: + optimizer.backward(loss) + else: + loss.backward() + return loss diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_amp/test_naive_fp16.py new file mode 100644 index 0000000000000000000000000000000000000000..95c5686ae80dcac6f20fea396a092ae04111b355 --- /dev/null +++ b/tests/test_amp/test_naive_fp16.py @@ -0,0 +1,95 @@ +import torch +import colossalai +import torch.multiprocessing as mp +from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp +from tests.components_to_test.registry import non_distributed_component_funcs +from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp + +from tests.components_to_test.registry import non_distributed_component_funcs + +import copy +import pytest +from functools import partial + + +def check_equal(a, b): + """ + This function checks if two tensors are equal within tolerance + """ + assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}' + + +def run_naive_amp(): + """ + In this test, we compare the naive fp16 optimizer implemented in colossalai + and fp32 torch optimizer + """ + + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + # create layer + test_models = ['repeated_computed_layers', 'nested_model', 'resnet18'] + for test_name in test_models: + get_component_func = non_distributed_component_funcs.get_callable(test_name) + model_builder, train_dataloader, _, optim_class, _ = get_component_func() + + # create model + naive_amp_model = model_builder(checkpoint=True).cuda() + apex_amp_model = copy.deepcopy(naive_amp_model) + + # create optimizer + naive_amp_optimizer = optim_class(naive_amp_model.parameters(), lr=1e-3) + apex_amp_optimizer = optim_class(apex_amp_model.parameters(), lr=1e-3) + + # inject naive and apex amp + naive_amp_config = dict(initial_scale=128) + naive_amp_model, naive_amp_optimizer = convert_to_naive_amp(naive_amp_model, naive_amp_optimizer, + naive_amp_config) + apex_amp_config = dict(opt_level='O2', loss_scale=128, keep_batchnorm_fp32=False) + apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) + + # create data + data_iter = iter(train_dataloader) + data, label = next(data_iter) + data = data.cuda() + + # forward pass + naive_amp_output = naive_amp_model(data) + apex_amp_output = apex_amp_model(data) + assert_close_loose(naive_amp_output, apex_amp_output) + + # backward + naive_amp_optimizer.backward(naive_amp_output.mean()) + apex_amp_optimizer.backward(apex_amp_output.mean()) + + # check grad + for naive_amp_param, apex_amp_param in zip(naive_amp_model.parameters(), apex_amp_model.parameters()): + assert_close_loose(naive_amp_param.grad, apex_amp_param.grad) + + # step + naive_amp_optimizer.step() + apex_amp_optimizer.step() + + # check updated param + for naive_amp_param, apex_amp_param in zip(naive_amp_model.parameters(), apex_amp_model.parameters()): + assert_close_loose(naive_amp_param, apex_amp_param) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + run_naive_amp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_naive_amp(): + world_size = 1 + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_naive_amp() diff --git a/tests/test_amp/test_torch_fp16.py b/tests/test_amp/test_torch_fp16.py new file mode 100644 index 0000000000000000000000000000000000000000..1372b08fa1fed0cdc56bc860c2b8f199908de85c --- /dev/null +++ b/tests/test_amp/test_torch_fp16.py @@ -0,0 +1,90 @@ +import torch +import colossalai +import torch.multiprocessing as mp +from tests.components_to_test.registry import non_distributed_component_funcs +from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.amp import convert_to_torch_amp, convert_to_apex_amp + +import copy +import pytest +from functools import partial + + +def run_torch_amp(): + """ + In this test, we compare the torch amp and apex amp implemented in colossalai + """ + + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + # create layer + test_models = ['resnet18', 'simple_net'] + for test_name in test_models: + get_component_func = non_distributed_component_funcs.get_callable(test_name) + model_builder, train_dataloader, _, optim_class, _ = get_component_func() + + # create model + torch_amp_model = model_builder(checkpoint=True).cuda() + apex_amp_model = copy.deepcopy(torch_amp_model) + + # create optimizer + torch_amp_optimizer = optim_class(torch_amp_model.parameters(), lr=1e-3) + apex_amp_optimizer = optim_class(apex_amp_model.parameters(), lr=1e-3) + + # inject torch and apex amp + torch_amp_config = dict(init_scale=1280, enabled=True) + torch_amp_model, torch_amp_optimizer, _ = convert_to_torch_amp(torch_amp_model, + torch_amp_optimizer, + amp_config=torch_amp_config) + apex_amp_config = dict(opt_level='O1', loss_scale=1280) + apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) + + # create data + data_iter = iter(train_dataloader) + data, label = next(data_iter) + data = data.cuda() + + # forward pass + torch_amp_output = torch_amp_model(data) + apex_amp_output = apex_amp_model(data) + assert_close_loose(torch_amp_output, apex_amp_output) + + for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()): + assert_close_loose(torch_amp_param, apex_amp_param) + + # backward + torch_amp_optimizer.backward(torch_amp_output.mean()) + apex_amp_optimizer.backward(apex_amp_output.mean()) + + # check grad + # In apex amp, grad is not scaled before backward, but torch amp does + for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()): + assert_close_loose(torch_amp_param.grad, apex_amp_param.grad * apex_amp_config['loss_scale']) + + # step + torch_amp_optimizer.step() + apex_amp_optimizer.step() + + # check updated param and grad + for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()): + assert_close_loose(torch_amp_param.grad, apex_amp_param.grad) + assert_close_loose(torch_amp_param, apex_amp_param) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + run_torch_amp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_torch_amp(): + world_size = 1 + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_torch_amp() diff --git a/tests/test_auto_parallel/__init__.py b/tests/test_auto_parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_auto_parallel/test_tensor_shard/__init__.py b/tests/test_auto_parallel/test_tensor_shard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..e666cb1753a730b116c496f8987acc502fa6129e --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py @@ -0,0 +1,172 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType +from colossalai.auto_parallel.tensor_shard.solver import ( + CostGraph, + GraphAnalyser, + Solver, + SolverOptions, + StrategiesConstructor, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port + + +class LinearModel(torch.nn.Module): + + def __init__(self, in_features, out_features): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x): + x = self.linear(x) + x = x * 2 + + return x + + +class ConvModel(torch.nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, bias=True): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias) + + def forward(self, x): + x = self.conv(x) + x = x * 2 + + return x + + +def check_linear_module(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = LinearModel(4, 8).cuda() + input = torch.rand(4, 4).cuda() + output_compare = model(input) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %linear_weight : [#users=1] = get_attr[target=linear.weight] + # %linear_bias : [#users=1] = get_attr[target=linear.bias] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) + # return mul + graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')}) + # def forward(self, x : torch.Tensor): + # linear_weight = self.linear.weight + # linear_bias = self.linear.bias + # linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None + # add = linear + linear_bias; linear = linear_bias = None + # mul = add * 2; add = None + # return mul + gm = ColoGraphModule(model, graph) + gm.recompile() + node_list = list(graph.nodes) + + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + linear_node = node_list[3] + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + graph_analyser = GraphAnalyser(gm) + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + ret = solver.call_solver_serialized_args() + solution = list(ret[0]) + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh) + + gm = runtime_apply_pass(gm) + gm.recompile() + output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict) + assert_close(output, output_compare) + + +def check_conv_module(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = ConvModel(3, 6, 2).cuda() + input = torch.rand(4, 3, 64, 64).cuda() + output_compare = model(input) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # %conv_bias : [#users=1] = get_attr[target=conv.bias] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {}) + # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) + # return mul + graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')}) + # def forward(self, x : torch.Tensor): + # conv_weight = self.conv.weight + # conv_bias = self.conv.bias + # conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None + # view = conv_bias.view([1, -1, 1, 1]); conv_bias = None + # add = conv2d + view; conv2d = view = None + # mul = add * 2; add = None + # return mul + gm = ColoGraphModule(model, graph) + + gm.recompile() + + node_list = list(graph.nodes) + conv_node = node_list[3] + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + graph_analyser = GraphAnalyser(gm) + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + ret = solver.call_solver_serialized_args() + solution = list(ret[0]) + + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh) + + gm = runtime_apply_pass(gm) + gm.recompile() + output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict) + assert_close(output, output_compare) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bias_addition_module(): + world_size = 4 + run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port()) + mp.spawn(run_func_linear, nprocs=world_size) + run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port()) + mp.spawn(run_func_conv, nprocs=world_size) + + +if __name__ == '__main__': + test_bias_addition_module() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py new file mode 100644 index 0000000000000000000000000000000000000000..5607587496f354a9b5a9623a3ea54c539dfd7ade --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py @@ -0,0 +1,64 @@ +import torch + +from colossalai.auto_parallel.tensor_shard.utils import ( + get_broadcast_shape, + is_broadcastable, + recover_sharding_spec_for_broadcast_shape, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.sharding_spec import ShardingSpec + + +def test_is_broadcastable(): + x1 = torch.rand(4, 4, 8) + x2 = torch.rand(1, 8) + assert is_broadcastable(x1.shape, x2.shape) + + x1 = torch.rand(4, 2, 8) + x2 = torch.rand(2, 8) + assert is_broadcastable(x1.shape, x2.shape) + + x1 = torch.rand(4, 2, 8) + x2 = torch.rand(4, 8) + assert not is_broadcastable(x1.shape, x2.shape) + + +def test_get_broadcast_shape(): + x1 = torch.rand(4, 4, 8) + x2 = torch.rand(1, 8) + assert get_broadcast_shape(x1.shape, x2.shape) == [4, 4, 8] + + x1 = torch.rand(4, 2, 8) + x2 = torch.rand(2, 8) + assert get_broadcast_shape(x1.shape, x2.shape) == [4, 2, 8] + + x1 = torch.rand(4, 2, 8) + x2 = torch.rand(8) + assert get_broadcast_shape(x1.shape, x2.shape) == [4, 2, 8] + + +def test_recover_sharding_spec_for_broadcast_shape(): + x1 = torch.rand(4, 1, 8) + x2 = torch.rand(2, 8) + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + + broadcast_shape = get_broadcast_shape(x1.shape, x2.shape) + logical_sharding_spec_for_x1 = ShardingSpec(device_mesh=device_mesh, + dim_partition_dict={ + 0: [0], + 1: [1] + }, + entire_shape=broadcast_shape) + physical_sharding_spec_for_x1, removed_dims = recover_sharding_spec_for_broadcast_shape( + logical_sharding_spec_for_x1, broadcast_shape, x1.shape) + print(physical_sharding_spec_for_x1) + + assert physical_sharding_spec_for_x1.entire_shape == x1.shape + # dim 1 for the physical tensor is of broadcast type MULTIPLE, so should ignore + assert physical_sharding_spec_for_x1.dim_partition_dict == {0: [0]} + assert physical_sharding_spec_for_x1.sharding_sequence == ['S0', 'R', 'R'] diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..96d96a4594c3506d4270671c36db37f3d5310656 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py @@ -0,0 +1,96 @@ +from copy import deepcopy +from pickletools import optimize + +import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule + +from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer + + +class ConvModel(nn.Module): + + def __init__(self, c_in, c_out): + super().__init__() + self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3) + self.relu = nn.ReLU() + + def forward(self, x): + x = x * 2 + x = self.conv1(x) + x = x / 2 + x = self.relu(x) + return x + + +def test_cost_graph(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + entire_shape = torch.Size((4, 16, 64, 64)) + + tracer = ColoTracer() + model = ConvModel(16, 32) + input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} + + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) + # %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {}) + # %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv1, 2), kwargs = {}) + # %relu : [#users=1] = call_module[target=relu](args = (%truediv,), kwargs = {}) + # return relu + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + # (x, mul):{(0, 0): 0} + # (mul, conv1):{(0, 0): 65547.1, (0, 1): 65547.1, (0, 2): 65547.1, (0, 3): 65547.1, (0, 4): 131105.30000000002, (0, 5): 131105.30000000002, (0, 6): 65547.1, (0, 7): 65547.1, (0, 8): 65547.1, (0, 9): 65547.1, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 131105.30000000002, (0, 14): 131105.30000000002} + # (conv1, truediv):{(0, 0): 0, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): 0, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (9, 0): inf, (10, 0): inf, (11, 0): inf, (12, 0): inf, (13, 0): inf, (14, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): inf, (4, 1): inf, (5, 1): 0, (6, 1): inf, (7, 1): inf, (8, 1): inf, (9, 1): inf, (10, 1): inf, (11, 1): inf, (12, 1): inf, (13, 1): inf, (14, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (9, 2): inf, (10, 2): inf, (11, 2): inf, (12, 2): inf, (13, 2): inf, (14, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (9, 3): inf, (10, 3): inf, (11, 3): inf, (12, 3): inf, (13, 3): inf, (14, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): inf, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): 0, (9, 4): inf, (10, 4): inf, (11, 4): inf, (12, 4): inf, (13, 4): inf, (14, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): inf, (6, 5): inf, (7, 5): 0, (8, 5): inf, (9, 5): 0, (10, 5): inf, (11, 5): inf, (12, 5): inf, (13, 5): inf, (14, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): inf, (7, 6): inf, (8, 6): inf, (9, 6): inf, (10, 6): 0, (11, 6): 0, (12, 6): 0, (13, 6): inf, (14, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): inf, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): inf, (8, 7): inf, (9, 7): inf, (10, 7): inf, (11, 7): inf, (12, 7): inf, (13, 7): 0, (14, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): inf, (7, 8): inf, (8, 8): inf, (9, 8): inf, (10, 8): inf, (11, 8): inf, (12, 8): inf, (13, 8): inf, (14, 8): 0} + # (truediv, relu):{(0, 0): 0, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): inf, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): inf, (4, 1): inf, (5, 1): inf, (6, 1): inf, (7, 1): inf, (8, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): 0, (5, 4): inf, (6, 4): inf, (7, 4): inf, (8, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): 0, (6, 5): inf, (7, 5): inf, (8, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): 0, (7, 6): inf, (8, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): inf, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): 0, (8, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): inf, (7, 8): inf, (8, 8): 0} + # (relu, output):{(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 123009.1, (5, 0): 123009.1, (6, 0): 0, (7, 0): 246019.30000000002, (8, 0): 246019.30000000002} + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + + # construct all node pairs + all_node_pairs = [] + + for node in graph.nodes: + if node.op == 'output': + continue + for child in node.users.keys(): + all_node_pairs.append((node, child)) + + for node_pair in all_node_pairs: + assert node_pair in cost_graph.edge_costs + + # construct merged node pairs + merged_node_pairs = [] + node_list = list(graph.nodes) + # add (conv1_weight, conv2d), (conv1_bias, view), (conv2d, add), (view, add), (add, output), (x, conv2d) into check node pairs + merged_node_pairs.append((node_list[0], node_list[4])) + merged_node_pairs.append((node_list[2], node_list[4])) + merged_node_pairs.append((node_list[3], node_list[5])) + merged_node_pairs.append((node_list[5], node_list[6])) + merged_node_pairs.append((node_list[4], node_list[6])) + merged_node_pairs.append((node_list[6], node_list[-1])) + cost_graph.simplify_graph() + for node_pair in all_node_pairs: + if node_pair in merged_node_pairs: + assert node_pair in cost_graph.edge_costs + else: + assert node_pair not in cost_graph.edge_costs + + +if __name__ == '__main__': + test_cost_graph() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_batch_norm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..2d3e71551eb2a8b4bc9640772e5fb77863431c79 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_batch_norm_handler.py @@ -0,0 +1,118 @@ +import torch +from torch.fx import GraphModule +import torch.nn as nn +import pytest + +from colossalai.fx.proxy import ColoProxy +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec +from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.batch_norm_handler import BatchNormHandler +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh + + +class BNModel(nn.Module): + + def __init__(self, c): + super().__init__() + self.bn = nn.BatchNorm2d(c) + + def forward(self, x): + x = x * 2 + x = self.bn(x) + return x + + +def test_bn_handler(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + entire_shape = torch.Size((4, 16, 64, 64)) + + tracer = ColoTracer() + model = BNModel(16) + input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) + # %bn : [#users=1] = call_module[target=bn](args = (%mul,), kwargs = {}) + # return bn + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + # [x, mul, bn, output] + nodes = [node for node in gm.graph.nodes] + + # find the sharding strategies for the input node of the bn node + # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]] + strategies_vector_for_input = StrategiesVector(nodes[1]) + sharding_option = (None, 0, 1) + for first_sharding_index in sharding_option: + for second_sharding_index in sharding_option: + if first_sharding_index is not None and second_sharding_index == first_sharding_index: + continue + if first_sharding_index is None: + first_dim_spec = _DimSpec([]) + else: + first_dim_spec = _DimSpec([first_sharding_index]) + + if second_sharding_index is None: + second_dim_spec = _DimSpec([]) + else: + second_dim_spec = _DimSpec([second_sharding_index]) + + replica_dim_spec = _DimSpec([]) + sharding_sequence = [first_dim_spec, second_dim_spec, replica_dim_spec, replica_dim_spec] + sharding_spec = ShardingSpec(device_mesh=device_mesh, + entire_shape=entire_shape, + sharding_sequence=sharding_sequence) + strategy_name = str(sharding_spec.sharding_sequence) + sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec) + strategies_vector_for_input.append(sharding_strategy) + setattr(nodes[1], 'strategies_vector', strategies_vector_for_input) + + # generate bn strategy + strategies_vector = StrategiesVector(node=nodes[2]) + bn_handler = BatchNormHandler( + node=nodes[2], + device_mesh=device_mesh, + strategies_vector=strategies_vector, + ) + bn_handler.register_strategy() + # ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01', + # 'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN'] + strategy_name_list = [strategy.name for strategy in bn_handler.strategies_vector] + + # RS = RS x S and strategies based on it, such as + # SS = RS x S + assert 'RS0 = RS0 x S0' in strategy_name_list + assert 'S1S0 = RS0 x S0' in strategy_name_list + assert 'RS1 = RS1 x S1' in strategy_name_list + assert 'S0S1 = RS1 x S1' in strategy_name_list + + # RR = RR x R and strategies based on it, such as + # SR = SR x R + assert 'RR = RR x R' in strategy_name_list + assert 'S0R = RR x R' in strategy_name_list + assert 'S1R = RR x R' in strategy_name_list + assert 'S01R = RR x R' in strategy_name_list + + # RS01 = RS01 x S01 + assert 'RS01 = RS01 x S01' in strategy_name_list + + # SR = SR x R WITH SYNC_BN + assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list + assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list + + # SS = SS x S WITH SYNC_BN + assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list + assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list + + # S01R = S01R x R WITH SYNC_BN + assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list + + +if __name__ == '__main__': + test_bn_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..7adc211cfc07cceb609863420553b1fae97a4a93 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py @@ -0,0 +1,75 @@ +from cProfile import run + +import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule + +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.testing.pytest_wrapper import run_on_environment_flag + + +class ConvModel(nn.Module): + + def __init__(self, c_in, c_out): + super().__init__() + self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2) + + def forward(self, x): + x1 = self.conv1(x) + x2 = x1 + 1 + x1 = torch.reshape(x1, [1, -1, 64, 1]) + x3 = self.conv2(x1) + x3 = torch.reshape(x3, [4, 1, 64, -1]) + x = x1 + x3 + + return x + + +@run_on_environment_flag(name='AUTO_PARALLEL') +def test_conv_handler(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + + tracer = ColoTracer() + model = ConvModel(16, 32) + input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv1 : [#users=2] = call_module[target=conv1](args = (%x,), kwargs = {}) + # %add : [#users=0] = call_function[target=operator.add](args = (%conv1, 1), kwargs = {}) + # %reshape : [#users=2] = call_function[target=torch.reshape](args = (%conv1, [1, -1, 64, 1]), kwargs = {}) + # %conv2 : [#users=1] = call_module[target=conv2](args = (%reshape,), kwargs = {}) + # %reshape_1 : [#users=1] = call_function[target=torch.reshape](args = (%conv2, [4, 1, 64, -1]), kwargs = {}) + # %add_1 : [#users=1] = call_function[target=operator.add](args = (%reshape, %reshape_1), kwargs = {}) + # return add_1 + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + # [x, conv1, add, reshape, conv2, reshape_1, add_1, output] + nodes = [node for node in gm.graph.nodes] + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + + strategies_constructor.build_strategies_and_cost() + strategy_map = strategies_constructor.strategy_map + # check a tensor add with a scalar case + conv1_strategies = strategy_map[nodes[1]] + add_strategies = strategy_map[nodes[2]] + add_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in add_strategies] + for strategy in conv1_strategies: + assert strategy.output_sharding_spec.sharding_sequence in add_strategies_cover_list + + # check two tensors element-wise add case + add_1_strategies = strategy_map[nodes[6]] + assert len(add_1_strategies) == 25 + + +if __name__ == '__main__': + test_conv_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..426d179f10d5c4fdbdc861bb9e563241eedf4c3b --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py @@ -0,0 +1,54 @@ +import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule + +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.testing.pytest_wrapper import run_on_environment_flag + + +class MatmulModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x1, x2): + x = torch.matmul(x1, x2) + + return x + + +@run_on_environment_flag(name='AUTO_PARALLEL') +def test_conv_handler(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + + tracer = ColoTracer() + model = MatmulModel() + input_sample = {'x1': torch.rand(4, 4, 8).to('meta'), 'x2': torch.rand(4, 1, 8, 4).to('meta')} + # graph(): + # %x1 : torch.Tensor [#users=1] = placeholder[target=x1] + # %x2 : torch.Tensor [#users=1] = placeholder[target=x2] + # %matmul : [#users=1] = call_function[target=torch.matmul](args = (%x1, %x2), kwargs = {}) + # return matmul + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + # [x1, x2, matmul, output] + nodes = [node for node in gm.graph.nodes] + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + + strategies_constructor.build_strategies_and_cost() + strategy_map = strategies_constructor.strategy_map + matmul_strategies = strategy_map[nodes[2]] + assert len(matmul_strategies) == 30 + + +if __name__ == '__main__': + test_conv_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..9342e06a040a67595b264c64d318f23fa634cd0f --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py @@ -0,0 +1,90 @@ +import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule + +from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import ConvHandler +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.proxy import ColoProxy +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec + + +class ConvModel(nn.Module): + + def __init__(self, c_in, c_out): + super().__init__() + self.conv = nn.Conv2d(c_in, c_out, kernel_size=3) + + def forward(self, x): + x = x * 2 + x = self.conv(x) + return x + + +def test_conv_handler(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + entire_shape = torch.Size((4, 16, 64, 64)) + + tracer = ColoTracer() + model = ConvModel(16, 32) + input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # %conv_bias : [#users=1] = get_attr[target=conv.bias] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%mul, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)}) + # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + # return add + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + + strategies_constructor.build_strategies_and_cost() + conv_node = list(graph.nodes)[4] + # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R'] + strategy_name_list = [strategy.name for strategy in conv_node.strategies_vector] + + # SS = SR x RS + assert 'S0S1 = S0R x RS1' in strategy_name_list + assert 'S1S0 = S1R x RS0' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R' in strategy_name_list + assert 'S1R = S1S0 x S0R' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RS = RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # RR= RR x RR + assert 'RR = RR x RR' in strategy_name_list + + # SR = SR x RR + assert 'S0R = S0R x RR' in strategy_name_list + assert 'S1R = S1R x RR' in strategy_name_list + assert 'S01R = S01R x RR' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + assert 'RR = RS01 x S01R' in strategy_name_list + + +if __name__ == '__main__': + test_conv_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..0a2dba1611f0b5ddf4572f62b2a3947d4b486f10 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py @@ -0,0 +1,83 @@ +import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule + +from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.dot_handler import DotHandler +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.proxy import ColoProxy +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec + + +class LinearModel(nn.Module): + + def __init__(self, in_features, out_features): + super().__init__() + self.linear = nn.Linear(in_features, out_features) + + def forward(self, x): + x = x * 2 + x = self.linear(x) + return x + + +@pytest.mark.skip('F.linear is not supported in deprecated handler') +def test_dot_handler(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + entire_shape = torch.Size((4, 8)) + + tracer = ColoTracer() + model = LinearModel(8, 16) + input_sample = {'x': torch.rand(4, 8).to('meta')} + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) + # %linear_weight : [#users=1] = get_attr[target=linear.weight] + # %linear_bias : [#users=1] = get_attr[target=linear.bias] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%mul, %linear_weight), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {}) + # return add + graph = tracer.trace(root=model, meta_args=input_sample) + + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + + strategies_constructor.build_strategies_and_cost() + linear_node = list(graph.nodes)[4] + + # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR'] + strategy_name_list = [strategy.name for strategy in linear_node.strategies_vector] + + # SS = SR x RS + assert 'S0S1 = S0R x RS1' in strategy_name_list + assert 'S1S0 = S1R x RS0' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R' in strategy_name_list + assert 'S1R = S1S0 x S0R' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + +if __name__ == '__main__': + test_dot_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_layer_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_layer_norm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..40e227cb53ebc09ae50a209c3a1afe9b0185d707 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_layer_norm_handler.py @@ -0,0 +1,70 @@ +import torch +from torch.fx import GraphModule +import torch.nn as nn +import pytest +from colossalai.auto_parallel.tensor_shard.deprecated import sharding_strategy + +from colossalai.fx.proxy import ColoProxy +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec +from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.layer_norm_handler import LayerNormHandler +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh + + +class LNModel(nn.Module): + + def __init__(self, c): + super().__init__() + self.ln = nn.LayerNorm(c) + + def forward(self, x): + x = x * 2 + x = self.ln(x) + return x + + +def test_bn_handler(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + entire_shape = torch.Size((4, 4, 128)) + + tracer = ColoTracer() + model = LNModel(128) + input_sample = {'x': torch.rand(4, 4, 128).to('meta')} + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) + # %ln : [#users=1] = call_module[target=ln](args = (%mul,), kwargs = {}) + # return ln + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + # [x, mul, ln, output] + nodes = [node for node in gm.graph.nodes] + sharding_spec_for_input = ShardingSpec(device_mesh, entire_shape, {}) + sharding_strategy_for_input = ShardingStrategy('node_1', sharding_spec_for_input) + strategies_vector_for_input = StrategiesVector(nodes[1]) + strategies_vector_for_input.append(sharding_strategy_for_input) + setattr(nodes[1], 'strategies_vector', strategies_vector_for_input) + + # generate bn strategy + strategies_vector = StrategiesVector(node=nodes[2]) + ln_handler = LayerNormHandler( + node=nodes[2], + device_mesh=device_mesh, + strategies_vector=strategies_vector, + ) + ln_handler.register_strategy() + # ['[S0, R, R] = [S0, R, R] x [R]', '[R, S0, R] = [R, S0, R] x [R]', '[S1, R, R] = [S1, R, R] x [R]', '[R, S1, R] = [R, S1, R] x [R]', + # '[S0, S1, R] = [S0, S1, R] x [R]', '[S1, S0, R] = [S1, S0, R] x [R]', '[S01, R, R] = [S01, R, R] x [R]', '[R, S01, R] = [R, S01, R] x [R]', 'RR = RR x R'] + strategy_name_list = [strategy.name for strategy in ln_handler.strategies_vector] + + assert len(strategy_name_list) == 9 + + +if __name__ == '__main__': + test_bn_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ac9df4cd825b36f46e79aa2ae049327400a8afe4 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +from torch.fx import GraphModule + +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer + + +class ConvModel(nn.Module): + + def __init__(self, c_in, c_out): + super().__init__() + self.conv = nn.Conv2d(c_in, c_out, kernel_size=3) + + def forward(self, x): + x = self.conv(x) + x = torch.flatten(x) + return x + + +def test_conv_handler(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + + tracer = ColoTracer() + model = ConvModel(16, 32) + input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # %conv_bias : [#users=1] = get_attr[target=conv.bias] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)}) + # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + # %flatten : [#users=1] = call_function[target=torch.flatten](args = (%add,), kwargs = {}) + # return flatten + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + # [x, conv, flatten, output] + nodes = [node for node in gm.graph.nodes] + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + + strategies_constructor.build_strategies_and_cost() + strategy_map = strategies_constructor.strategy_map + add_strategies = strategy_map[nodes[5]] + flatten_strategies = strategy_map[nodes[6]] + flatten_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in flatten_strategies] + for strategy in add_strategies: + assert strategy.output_sharding_spec.sharding_sequence in flatten_strategies_cover_list + + +if __name__ == '__main__': + test_conv_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..294a59fc8548737920f98dd1d07f30dd0ac7831f --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py @@ -0,0 +1,66 @@ +import torch +from torch.fx import GraphModule +import torch.nn as nn +import pytest + +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing.pytest_wrapper import run_on_environment_flag + + +class ConvModel(nn.Module): + + def __init__(self, dim_in, dim_out): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + + def forward(self, condition, x, y): + output = torch.where(condition, x, y) + + return output + + +@run_on_environment_flag(name='AUTO_PARALLEL') +def test_where_handler(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + + tracer = ColoTracer() + model = ConvModel(16, 32) + input_sample = { + 'condition': torch.rand(16, 32).to('meta'), + 'x': torch.rand(16, 32).to('meta'), + 'y': torch.rand(16, 32).to('meta') + } + # graph(): + # %condition : torch.Tensor [#users=1] = placeholder[target=condition] + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %y : torch.Tensor [#users=1] = placeholder[target=y] + # %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {}) + # return where + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + + # [condition, x, y, where, output] + nodes = [node for node in gm.graph.nodes] + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + + strategies_constructor.build_strategies_and_cost() + strategy_map = strategies_constructor.strategy_map + # check a tensor add with a scalar case + where_node = strategy_map[nodes[3]] + # ['[S0, S1] = [S0, S1] x [S0, S1] x [S0, S1]', '[S1, S0] = [S1, S0] x [S1, S0] x [S1, S0]', '[S01, R] = [S01, R] x [S01, R] x [S01, R]', + # '[R, S01] = [R, S01] x [R, S01] x [R, S01]', '[S0, R] = [S0, R] x [S0, R] x [S0, R]', '[R, S0] = [R, S0] x [R, S0] x [R, S0]', + # '[S1, R] = [S1, R] x [S1, R] x [S1, R]', '[R, S1] = [R, S1] x [R, S1] x [R, S1]', '[R, R] = [R, R] x [R, R] x [R, R]'] + assert len(where_node) == 9 + + +if __name__ == '__main__': + test_where_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..3286b325c8ab6e83f1a3221b2eaa5a0d142245d1 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py @@ -0,0 +1,86 @@ +from functools import partial +import pytest +import torch +import torch.multiprocessing as mp +from torch.fx import GraphModule +import torch.nn as nn +import pytest +from colossalai.initialize import launch +from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.logging import disable_existing_loggers +from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph +from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor + +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass +from colossalai.auto_parallel.tensor_shard.deprecated import Solver +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.testing.pytest_wrapper import run_on_environment_flag + + +class ConvModel(nn.Module): + + def __init__(self, c_in, c_out): + super().__init__() + self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False) + + def forward(self, x): + x = self.conv(x) + return x + + +def check_apply(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + input = torch.rand(4, 4, 4, 4).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + entire_shape = torch.Size((4, 4, 8, 8)) + + tracer = ColoTracer() + model = ConvModel(4, 4).cuda() + origin_output = model(input) + input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')} + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {}) + # return conv + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + graph_analyser = GraphAnalyser(gm) + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + ret = solver.call_solver_serialized_args() + solution = list(ret[0]) + sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh) + shape_consistency_pass(gm) + gm.recompile() + nodes = [node for node in gm.graph.nodes] + # TODO: wrap the gm to avoid the influence of the user training code + output = gm(input, sharding_spec_dict, origin_spec_dict) + assert output.equal(origin_output) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_apply(): + world_size = 4 + run_func = partial(check_apply, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_apply() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..baa70727a2e51075ab945676bb89535616092fda --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py @@ -0,0 +1,79 @@ +from copy import deepcopy + +import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule + +from colossalai.auto_parallel.tensor_shard.deprecated import Solver +from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph +from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.testing.pytest_wrapper import run_on_environment_flag + + +class ConvModel(nn.Module): + + def __init__(self, c_in, c_out): + super().__init__() + self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3) + self.conv2 = nn.Conv2d(c_out, c_out, kernel_size=3) + self.conv3 = nn.Conv2d(c_out, c_out, kernel_size=3) + self.relu = nn.ReLU() + + def forward(self, x): + x = x * 2 + x = self.conv1(x) + x = self.conv2(x) + x = x / 2 + x = self.conv3(x) + x = self.relu(x) + return x + + +@run_on_environment_flag(name='AUTO_PARALLEL') +def test_solver(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer() + model = ConvModel(16, 32) + input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} + + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) + # %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {}) + # %conv2 : [#users=1] = call_module[target=conv2](args = (%conv1,), kwargs = {}) + # %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv2, 2), kwargs = {}) + # %conv3 : [#users=1] = call_module[target=conv3](args = (%truediv,), kwargs = {}) + # %relu : [#users=1] = call_module[target=relu](args = (%conv3,), kwargs = {}) + # return relu + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + graph_analyser = GraphAnalyser(gm) + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + ret = solver.call_solver_serialized_args() + + # [ 0 0 13 13 13 13 13 0] + strategies_combination_list = ret[0] + assert solver.leaf_strategies[2][13].name == 'S01R = S01R x RR' + + +if __name__ == '__main__': + test_solver() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..e90d6b15308caab36f6f6512f9323526ca8b2115 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py @@ -0,0 +1,81 @@ +import torch +from torch.fx import GraphModule +import torch.nn as nn +import pytest + +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.device.device_mesh import DeviceMesh +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph +from copy import deepcopy +from colossalai.auto_parallel.tensor_shard.deprecated import Solver +import transformers +from colossalai.auto_parallel.tensor_shard.deprecated.constants import * +from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.testing.pytest_wrapper import run_on_environment_flag + +BATCH_SIZE = 8 +SEQ_LENGHT = 8 + + +@run_on_environment_flag(name='AUTO_PARALLEL') +def test_cost_graph(): + physical_mesh_id = torch.arange(0, 8) + mesh_shape = (2, 4) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer() + config = transformers.GPT2Config(n_position=1024, n_layer=1, n_head=12) + model = transformers.GPT2LMHeadModel(config=config) + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + meta_args = {k: v.to('meta') for k, v in kwargs.items()} + + graph = tracer.trace(root=model, meta_args=meta_args) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + graph_analyser = GraphAnalyser(gm) + liveness_list = graph_analyser.liveness_analysis() + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + print(graph) + strategies_constructor.build_strategies_and_cost() + for check_node, strategies_vector in strategies_constructor.strategy_map.items(): + print(check_node, len(strategies_vector)) + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + # solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0) + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + + ret = solver.call_solver_serialized_args() + print(ret) + strategies_list = list(ret[0]) + print(strategies_list) + computation_cost = 0 + communication_cost = 0 + memory_cost = 0 + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + for index, node in enumerate(nodes): + print(node.name, node.strategies_vector[strategies_list[index]].name) + computation_cost += node.strategies_vector[strategies_list[index]].compute_cost + communication_cost += node.strategies_vector[strategies_list[index]].communication_cost + node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost + if isinstance(node_memory_cost, tuple): + node_memory_cost = node_memory_cost[0] + memory_cost += node_memory_cost + + print(f'computation cost is {computation_cost}') + print(f'communication cost is {communication_cost}') + print(f'memory cost is {memory_cost}') + + +if __name__ == '__main__': + test_cost_graph() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..415156ed6545d4df5b071875e96a875e3031588a --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py @@ -0,0 +1,94 @@ +import torch +from torch.fx import GraphModule +import torch.nn as nn +import pytest + +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.device.device_mesh import DeviceMesh +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph +from copy import deepcopy +from colossalai.auto_parallel.tensor_shard.deprecated import Solver +from torchvision.models import resnet34, resnet50 +from colossalai.auto_parallel.tensor_shard.deprecated.constants import * +from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.testing.pytest_wrapper import run_on_environment_flag + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim * 4) + self.linear2 = torch.nn.Linear(dim * 4, dim) + self.dropout = torch.nn.Dropout(0) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.linear1(x) + x = self.dropout(x) + x = self.relu(x) + x = self.linear2(x) + return x + + +@run_on_environment_flag(name='AUTO_PARALLEL') +def test_cost_graph(): + physical_mesh_id = torch.arange(0, 8) + mesh_shape = (2, 4) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer() + model = MLP(32) + + input_sample = {'x': torch.rand(16, 32).to('meta')} + + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {}) + # %dropout : [#users=1] = call_module[target=dropout](args = (%linear1,), kwargs = {}) + # %relu : [#users=1] = call_module[target=relu](args = (%dropout,), kwargs = {}) + # %linear2 : [#users=1] = call_module[target=linear2](args = (%relu,), kwargs = {}) + # return linear2 + graph = tracer.trace(root=model, meta_args=input_sample) + + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + graph_analyser = GraphAnalyser(gm) + liveness_list = graph_analyser.liveness_analysis() + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + # # megatron mode if no memory constraints + # solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + # all sharding on out feature dim if memory budget is not sufficient for megatron mode + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=5500.0) + + ret = solver.call_solver_serialized_args() + strategies_list = list(ret[0]) + computation_cost = 0 + communication_cost = 0 + memory_cost = 0 + for index, node in enumerate(graph.nodes): + print(node.name, node.strategies_vector[strategies_list[index]].name) + computation_cost += node.strategies_vector[strategies_list[index]].compute_cost + communication_cost += node.strategies_vector[strategies_list[index]].communication_cost + node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost + if isinstance(node_memory_cost, tuple): + node_memory_cost = node_memory_cost[0] + memory_cost += node_memory_cost + + print(f'computation cost is {computation_cost}') + print(f'communication cost is {communication_cost}') + print(f'memory cost is {memory_cost}') + + +if __name__ == '__main__': + test_cost_graph() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..9be1a5d963a9dd2b15b51ff844d6643227ffc8a3 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py @@ -0,0 +1,103 @@ +from copy import deepcopy + +import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule + +from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import CONV_STRATEGIES_LIST +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.proxy import ColoProxy +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec + + +class ConvModel(nn.Module): + + def __init__(self, c_in, c_out): + super().__init__() + self.conv = nn.Conv2d(c_in, c_out, kernel_size=3) + + def forward(self, x): + x = x * 2 + x = self.conv(x) + return x + + +def test_strategies_constructor(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + entire_shape = torch.Size((4, 16, 64, 64)) + + tracer = ColoTracer() + model = ConvModel(16, 32) + input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # %conv_bias : [#users=1] = get_attr[target=conv.bias] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%mul, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)}) + # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + # return add + graph = tracer.trace(root=model, meta_args=input_sample) + print(graph) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + + assert strategies_constructor.leaf_strategies == [] + assert strategies_constructor.strategy_map == {} + strategies_constructor.build_strategies_and_cost() + + # check leaf_strategies + + # In fast mode, placeholder node only has replica strategy. + assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder' + + # Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec. + assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]_0' + + # Third node is conv. + conv_check_list = deepcopy(CONV_STRATEGIES_LIST) + for strategy in strategies_constructor.leaf_strategies[4]: + conv_check_list.remove(strategy.name) + assert len(conv_check_list) == 0 + + # In fast mode, output node only has replica strategy. + assert strategies_constructor.leaf_strategies[7][0].name == 'Replica Output' + + # check strategy_map + + nodes = [node for node in graph.nodes] + # In fast mode, placeholder node only has replica strategy. + x = nodes[0] + assert strategies_constructor.strategy_map[x][0].name == 'Replica Placeholder' + + # Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec. + mul = nodes[1] + assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]_0' + + # fifth node is conv. + conv = nodes[4] + conv_check_list = deepcopy(CONV_STRATEGIES_LIST) + for strategy in strategies_constructor.strategy_map[conv]: + conv_check_list.remove(strategy.name) + assert len(conv_check_list) == 0 + + # In fast mode, output node only has replica strategy. + output = nodes[-1] + assert strategies_constructor.strategy_map[output][0].name == 'Replica Output' + + +if __name__ == '__main__': + test_strategies_constructor() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py b/tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..d573c65908f78ac94b222be4e30b7a6fba9804f5 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py @@ -0,0 +1,214 @@ +import copy +import random +from functools import partial +from typing import Optional, Tuple, Union + +import numpy as np +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import transformers +from torch.fx import GraphModule +from transformers.activations import ACT2FN +from transformers.models.gpt2.modeling_gpt2 import GPT2MLP +from transformers.pytorch_utils import Conv1D + +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP +from colossalai.auto_parallel.tensor_shard.solver import ( + CostGraph, + GraphAnalyser, + Solver, + SolverOptions, + StrategiesConstructor, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global +from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port + +BATCH_SIZE = 1 +SEQ_LENGTH = 32 +HIDDEN_DIM = 768 + +seed = 128 +torch.manual_seed(seed) +torch.cuda.manual_seed_all(seed) +np.random.seed(seed) +random.seed(seed) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + + +class GPT2MLP(nn.Module): + + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + # We temporarily banned the Dropout layer because the rng state need + # to process to get the correct result. + # self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + # TODO: the rng state need to be fixed for distributed runtime + # hidden_states = self.dropout(hidden_states) + return hidden_states + + +def check_mlp_layer(rank, model_cls, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM) + model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda') + input = torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('cuda') + test_model = copy.deepcopy(model) + test_input = copy.deepcopy(input) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer() + + input_sample = { + 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), + } + + graph = tracer.trace(root=model, meta_args=input_sample) + print(graph) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + print(gm) + graph_analyser = GraphAnalyser(gm) + liveness_list = graph_analyser.liveness_analysis() + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1) + ret = solver.call_solver_serialized_args() + + solution = list(ret[0]) + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( + gm, solution, device_mesh, strategies_constructor) + gm = runtime_apply_pass(gm) + gm.recompile() + cuda_rng_state = torch.cuda.get_rng_state() + cpu_rng_state = torch.get_rng_state() + origin_output = test_model(test_input) + torch.cuda.set_rng_state(cuda_rng_state) + torch.set_rng_state(cpu_rng_state) + output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict) + assert_close(output, origin_output, rtol=1e-03, atol=1e-04) + + #*******************backward starting******************* + cuda_rng_state = torch.cuda.get_rng_state() + output.sum().backward() + torch.cuda.set_rng_state(cuda_rng_state) + origin_output.sum().backward() + origin_param_dict = dict(test_model.named_parameters()) + if rank == 0: + print("*******************backward starting*******************") + for name, param in model.named_parameters(): + param_grad = param.grad + origin_param_grad = origin_param_dict[name].grad + origin_param_size = origin_param_grad.shape[-1] + print(name, param_grad, origin_param_grad) + if name == 'c_fc.bias': + assert_close_loose(param_grad, + origin_param_grad.narrow(0, 0, origin_param_size // 2), + rtol=1e-03, + atol=1e-03) + else: + assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03) + print("*******************backward finished*******************") + if rank == 1: + for name, param in model.named_parameters(): + param_grad = param.grad + origin_param_grad = origin_param_dict[name].grad + origin_param_size = origin_param_grad.shape[-1] + if name == 'c_fc.bias': + assert_close_loose(param_grad, + origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2), + rtol=1e-03, + atol=1e-03) + else: + assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03) + if rank == 2: + for name, param in model.named_parameters(): + param_grad = param.grad + origin_param_grad = origin_param_dict[name].grad + origin_param_size = origin_param_grad.shape[-1] + if name == 'c_fc.bias': + assert_close_loose(param_grad, + origin_param_grad.narrow(0, 0, origin_param_size // 2), + rtol=1e-03, + atol=1e-03) + else: + assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03) + if rank == 3: + for name, param in model.named_parameters(): + param_grad = param.grad + origin_param_grad = origin_param_dict[name].grad + origin_param_size = origin_param_grad.shape[-1] + if name == 'c_fc.bias': + assert_close_loose(param_grad, + origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2), + rtol=1e-03, + atol=1e-03) + else: + assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03) + + #*******************backward finished******************* + + #*******************strategy selected******************* + if rank == 0: + print("*******************strategy selected*******************") + strategies_list = solver.last_s_val + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + computation_cost = 0 + communication_cost = 0 + memory_cost = 0 + for index, node in enumerate(nodes): + print(node.name, node.strategies_vector[strategies_list[index]].name) + computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total + communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total + node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total + if isinstance(node_memory_cost, tuple): + node_memory_cost = node_memory_cost[0] + memory_cost += node_memory_cost.activation + node_memory_cost.parameter + + print(f'computation cost is {computation_cost}') + print(f'communication cost is {communication_cost}') + print(f'memory cost is {memory_cost}') + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@parameterize('model_cls', [GPT2MLP]) +@rerun_if_address_is_in_use() +def test_mlp_layer(model_cls): + world_size = 4 + run_func = partial(check_mlp_layer, model_cls=model_cls, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_mlp_layer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..f5de7bf702ff178e7ef88a86b4ddf7e98a0f5833 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py @@ -0,0 +1,55 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser +from colossalai.fx import ColoGraphModule, ColoTracer + + +class LinearModel(nn.Module): + + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(4, 4) + self.relu = nn.ReLU(inplace=True) + self.linear2 = nn.Linear(4, 4) + + def forward(self, x1, x2): + x1 = x1 * 2 + x1 = self.linear1(x1) + x1 = self.relu(x1) + x1 = self.linear2(x1) + out = x1 + x2 + return out + + +def test_liveness_analysis(): + model = LinearModel() + tracer = ColoTracer() + graph = tracer.trace(model, + meta_args={ + 'x1': torch.rand(4, 4, device='meta'), + 'x2': torch.rand(4, 4, device='meta') + }) + gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__) + + graph_analyser = GraphAnalyser(gm) + liveness_list = graph_analyser.liveness_analysis() + stage_count = len(liveness_list) + + # if a LiveStage is covered by another LiveStage, we just keep the larger one. + assert stage_count == 1 + + # a variable named `relu` must exist + # and this live var must have inplace = True + assert liveness_list[0].all_live_vars.exists('relu') + relu_var = liveness_list[0].all_live_vars.get('relu') + assert relu_var.is_inplace + + # the unique vars must be fewer than the all vars since in-place ops exist + all_live_vars = liveness_list[0].all_live_vars + unique_live_vars = liveness_list[0].unique_live_vars + assert len(unique_live_vars) + 1 == len(all_live_vars) + + +if __name__ == '__main__': + test_liveness_analysis() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..f468b1ab2113f0b221d2856c7882c8f8773308a2 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py @@ -0,0 +1,61 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + + +def _ReLU_module_mem_test(rank, world_size, port): + """This function is for ReLU memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.ReLU()).cuda() + input = torch.rand(4, 128, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 1 + # total number of target node strategies + strategy_number = 1 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_ReLU_meta_concrete_info_match(): + world_size = 4 + run_func_module = partial(_ReLU_module_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + test_ReLU_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_batchnorm_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_batchnorm_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..7acbbed8f25abf3e71f6a6f517581e6a61de6457 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_batchnorm_metainfo.py @@ -0,0 +1,60 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + + +def _batchnorm_module_mem_test(rank, world_size, port): + """This function is for batchnorm memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.BatchNorm2d(128)).cuda() + input = torch.rand(4, 128, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 1 + # total number of target node strategies + strategy_number = 4 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_batchnorm_meta_concrete_info_match(): + world_size = 4 + run_func_module = partial(_batchnorm_module_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + test_batchnorm_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..1b745d8906b01f9e373286d87582388061354e41 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py @@ -0,0 +1,71 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + + +class BinaryElementwiseOpModule(nn.Module): + + def __init__(self, token=torch.add, shape=64) -> None: + super().__init__() + self.token = token + self.param = nn.Parameter(torch.rand(shape)) + + def forward(self, input): + return input + self.param + + +def _binary_elementwise_mem_test(rank, world_size, port): + """This function is for binary elementwise ops memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = BinaryElementwiseOpModule(token=torch.add, shape=1024).cuda() + input = torch.rand(32, 1024).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 2 + # total number of target node strategies + strategy_number = 9 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_binary_elementwise_meta_concrete_info_match(): + world_size = 4 + run_func_module = partial(_binary_elementwise_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + test_binary_elementwise_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..a973a8182cf334c246519a23be7452c442a906d0 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py @@ -0,0 +1,113 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + + +class ConvFunctionModule(nn.Module): + + def __init__(self, in_channels=4, out_channels=64, kernel_size=3): + super().__init__() + self.conv_weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + + def forward(self, input): + return nn.functional.conv2d(input, self.conv_weight) + + +def _conv_module_mem_test(rank, bias, world_size, port): + """This function is for conv memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.Conv2d(4, 64, 3, padding=1, bias=bias)).cuda() + input = torch.rand(4, 4, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 1 + # total number of target node strategies + strategy_number = 16 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_conv_meta_concrete_info_match(bias=False): + world_size = 4 + run_func_module = partial(_conv_module_mem_test, bias=bias, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +def _conv_function_mem_test(rank, world_size, port): + """This function is for conv function memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = ConvFunctionModule().cuda() + input = torch.rand(4, 4, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 2 + # total number of target node strategies + strategy_number = 16 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_conv_function_concrete_info_match(): + world_size = 4 + run_func_module = partial(_conv_function_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + # test_conv_meta_concrete_info_match() + test_conv_function_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..62fe11e2229b8aabc912b5f0bcad877000bd5a61 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py @@ -0,0 +1,111 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + +if torch.__version__ >= '1.12.0': + from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register + + +class MyModule(nn.Module): + + def __init__(self, in_features=64, out_features=128): + super().__init__() + self.fc_weight = nn.Parameter(torch.randn(out_features, in_features)) + + def forward(self, input): + return nn.functional.linear(input, self.fc_weight) + + +def _linear_module_mem_test(rank, world_size, port): + """This function is for linear memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether linear module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.Linear(64, 128, bias=False)).cuda() + input = torch.rand(8, 8, 16, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # memory test + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=1, + strategy_number=13, + input_args=[input], + meta_arg_names=["input"]) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_module_meta_concrete_info_match(): + world_size = 4 + run_func_module = partial(_linear_module_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +def _linear_function_mem_test(rank, world_size, port): + """This function is for linear memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether linear module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = MyModule().cuda() + input = torch.rand(8, 8, 16, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # memory test + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=2, + strategy_number=23, + input_args=[input], + meta_arg_names=["input"]) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_function_meta_concrete_info_match(): + world_size = 4 + run_func_module = partial(_linear_function_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + # test_linear_module_meta_concrete_info_match() + test_linear_function_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..529686d27d1945ac98a66fb440aaf440e5e2bb97 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py @@ -0,0 +1,102 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + + +def _adaptiveavgpool_module_mem_test(rank, world_size, port): + """This function is for AdaptiveAvgPool memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.AdaptiveAvgPool2d((16, 16))).cuda() + input = torch.rand(4, 128, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 1 + # total number of target strategies + strategy_number = 1 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_adaptiveavgpool_meta_concrete_info_match(): + world_size = 4 + run_func_module = partial(_adaptiveavgpool_module_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +def _maxpool_module_mem_test(rank, world_size, port): + """This function is for MaxPool memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.MaxPool2d((16, 16))).cuda() + input = torch.rand(4, 128, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 1 + # total number of target node strategies + strategy_number = 9 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_maxpool_meta_concrete_info_match(): + world_size = 4 + run_func_module = partial(_maxpool_module_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + test_adaptiveavgpool_meta_concrete_info_match() + test_maxpool_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7c06f2ee9e202ee10ccfbd2ed01ad7be1c020af0 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py @@ -0,0 +1,128 @@ +import copy +from pprint import pprint +from typing import Dict, List + +import torch +from torch.fx import GraphModule + +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType +from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer + +if torch.__version__ >= '1.12.0': + from colossalai.auto_parallel.meta_profiler import MetaInfo + + +def mem_test_for_node_strategy(rank: int, + model: torch.nn.Module, + device_mesh: DeviceMesh, + node_index: int, + strategy_number: int, + input_args: List[torch.Tensor], + meta_arg_names: List[str], + input_kwargs: Dict[str, torch.Tensor] = {}): + for strategy_index in range(strategy_number): + # We need to copy the model to avoid do backward more than once in same graph + model_to_shard, args_to_shard, kwargs_to_shard = copy.deepcopy(model), copy.deepcopy(input_args), copy.deepcopy( + input_kwargs) + + tracer = ColoTracer() + input_sample = {} + for input_arg, meta_arg_name in zip(input_args, meta_arg_names): + input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta') + for meta_kwarg_name, input_kwarg in input_kwargs.items(): + input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta') + graph = tracer.trace(root=model_to_shard, meta_args=input_sample) + gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + target_node = list(graph.nodes)[node_index] + + # solution construction + # construct the strategy for the target node + solution_len = len(strategies_constructor.leaf_strategies) + solution = [0] * solution_len + solution[node_index] = strategy_index + + # construct the strategy for the output node + placeholder_strategy = list(graph.nodes)[-1].strategies_vector[0] + + output_key = next(key for key in target_node.strategies_vector[strategy_index].sharding_specs.keys() + if key.type == OperationDataType.OUTPUT) + placeholder_strategy.sharding_specs[output_key] = target_node.strategies_vector[strategy_index].sharding_specs[ + output_key] + + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( + gm, solution, device_mesh) + gm = runtime_apply_pass(gm) + gm.recompile() + gm: GraphModule + + num_of_strategies = len(target_node.strategies_vector) + if rank == 0: + print("=======================") + print(f"#strategy_index: {strategy_index + 1}/{num_of_strategies}") + pprint(target_node.strategies_vector[strategy_index]) + + # warmup + with torch.no_grad(): + output = gm(*args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard) + + del output + # forward memory compare + if rank == 0: + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + output = gm(*args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard) + + if rank == 0: + # print forward memory allocated and peak memory stats in kb + print( + f"forward memory allocated: {(torch.cuda.memory_allocated() - mem_stamp0) / 1024} kb, peak memory stats: {(torch.cuda.max_memory_allocated() - mem_stamp0) / 1024} kb" + ) + + # backward memory compare + grad_tensors = torch.ones_like(output) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + torch.autograd.backward(output, grad_tensors) + + if rank == 0: + # print backward memory allocated and peak memory stats in kb + print( + f"backward memory allocated: {(torch.cuda.memory_allocated() - mem_stamp0) / 1024} kb, peak memory stats: {(torch.cuda.max_memory_allocated() - mem_stamp0) / 1024} kb" + ) + + # estimated memory + if target_node.op == "call_module": + metainfo = MetaInfo(target_node.strategies_vector[strategy_index], + target_node.graph.owning_module.get_submodule(target_node.target)) + else: + metainfo = MetaInfo(target_node.strategies_vector[strategy_index], target_node.target) + + print("estimated memory:") + print( + f"forward activation: {metainfo.memory_cost.fwd.activation / 1024} kb, forward param: {metainfo.memory_cost.fwd.parameter / 1024} kb" + ) + print( + f"forward temp: {metainfo.memory_cost.fwd.temp / 1024} kb, forward buffer: {metainfo.memory_cost.fwd.buffer / 1024} kb" + ) + print( + f"backward activation: {metainfo.memory_cost.bwd.activation / 1024} kb, backward param: {metainfo.memory_cost.bwd.parameter / 1024} kb" + ) + print( + f"backward temp: {metainfo.memory_cost.bwd.temp / 1024} kb, backward buffer: {metainfo.memory_cost.bwd.buffer / 1024} kb" + ) + print("=======================") diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/__init__.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc15e403f35e9b650f0d78a9ccd79f691f42a44 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py @@ -0,0 +1,282 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class AddBMMTensorMethodModule(nn.Module): + + def __init__(self, using_kwargs): + super().__init__() + self.using_kwargs = using_kwargs + + def forward(self, bias, x1, x2): + if self.using_kwargs: + output = bias.addbmm(x1, x2, alpha=2, beta=3) + else: + output = bias.addbmm(x1, x2) + return output + + +class AddBMMTorchFunctionModule(nn.Module): + + def __init__(self, using_kwargs): + super().__init__() + self.using_kwargs = using_kwargs + + def forward(self, bias, x1, x2): + if self.using_kwargs: + output = torch.addbmm(bias, x1, x2, alpha=2, beta=3) + else: + output = torch.addbmm(bias, x1, x2) + return output + + +def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = module(using_kwargs).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + x1 = torch.rand(4, 8, 16).cuda() + x2 = torch.rand(4, 16, 8).cuda() + bias = torch.rand(bias_shape).cuda() + # the index of addbmm node in computation graph + node_index = 3 + # strategy number of addbmm node on 2d device mesh + strategy_number = 7 + # construct input args + input_args = [bias, x1, x2] + # construct meta arg names + meta_arg_names = ['bias', 'x1', 'x2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + tracer = ColoTracer() + # graph(): + # %bias : torch.Tensor [#users=1] = placeholder[target=bias] + # %x1 : torch.Tensor [#users=1] = placeholder[target=x1] + # %x2 : torch.Tensor [#users=1] = placeholder[target=x2] + # %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {}) + # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {}) + # return add + graph = tracer.trace(model, + meta_args={ + 'bias': torch.rand(*bias_shape).to('meta'), + "x1": torch.rand(4, 8, 16).to('meta'), + 'x2': torch.rand(4, 16, 8).to('meta') + }) + gm = ColoGraphModule(model, graph) + + bmm_mod_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(bmm_mod_node) + + # build handler + handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 8, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([4, 16, 8]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + + assert mapping['output'].name == "bmm" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 8, 8]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + for name in strategy_name_list: + print(name) + # one batch dim + assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list + + # two batch dim + assert 'Sb01 = Sb01 x Sb01' in strategy_name_list + + # SbSi = SbSi x Sb + assert 'Sb0Si1 = Sb0Si1 x Sb0' in strategy_name_list + assert 'Sb1Si0 = Sb1Si0 x Sb1' in strategy_name_list + + # SbSj = SbR x SbSj + assert 'Sb0Sj1 = Sb0R x Sb0Sj1' in strategy_name_list + assert 'Sb1Sj0 = Sb1R x Sb1Sj0' in strategy_name_list + + # SbR = SbSk x SbSk + assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list + assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] + assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + +def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (1, 4) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + model = module(using_kwargs).cuda() + x1 = torch.rand(4, 8, 16).cuda() + x2 = torch.rand(4, 16, 8).cuda() + bias = torch.rand(bias_shape).cuda() + # the index of addbmm node in computation graph + node_index = 3 + # strategy number of addbmm node on 2d device mesh + strategy_number = 1 + # construct input args + input_args = [bias, x1, x2] + # construct meta arg names + meta_arg_names = ['bias', 'x1', 'x2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + + tracer = ColoTracer() + # graph(): + # %bias : torch.Tensor [#users=1] = placeholder[target=bias] + # %x1 : torch.Tensor [#users=1] = placeholder[target=x1] + # %x2 : torch.Tensor [#users=1] = placeholder[target=x2] + # %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {}) + # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {}) + # return add + graph = tracer.trace(model, + meta_args={ + 'bias': torch.rand(*bias_shape).to('meta'), + "x1": torch.rand(4, 8, 16).to('meta'), + 'x2': torch.rand(4, 16, 8).to('meta') + }) + gm = ColoGraphModule(model, graph) + bmm_mod_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(bmm_mod_node) + + # build handler + handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 8, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([4, 16, 8]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + + assert mapping['output'].name == "bmm" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 8, 8]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + assert len(strategy_name_list) == 1 + # one batch dim + assert 'Sb0 = Sb0 x Sb0' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] + assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + +@pytest.mark.skip("skip due to bias cases not ready") +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) +@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) +@parameterize('using_kwargs', [True, False]) +@rerun_if_address_is_in_use() +def test_2d_device_mesh(module, bias_shape, using_kwargs): + world_size = 4 + run_func = partial(check_2d_device_mesh, + module=module, + bias_shape=bias_shape, + world_size=world_size, + using_kwargs=using_kwargs, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +@pytest.mark.skip("skip due to bias cases not ready") +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) +@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) +@parameterize('using_kwargs', [True, False]) +@rerun_if_address_is_in_use() +def test_1d_device_mesh(module, bias_shape, using_kwargs): + world_size = 4 + run_func = partial(check_1d_device_mesh, + module=module, + bias_shape=bias_shape, + using_kwargs=using_kwargs, + world_size=world_size, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_1d_device_mesh() + test_2d_device_mesh() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..767864296093cebe9e857793cbd50fa8d1ee4d36 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py @@ -0,0 +1,169 @@ +from faulthandler import disable +from functools import partial +from xml.dom import WrongDocumentErr + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from typing_extensions import Self + +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class AddmmModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, m1, m2): + x = torch.addmm(input, m1, m2, beta=3, alpha=2) + return x + + +def check_linear_function_handler(rank, input_shape, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = AddmmModel().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + input = torch.rand(input_shape).cuda() + m1 = torch.rand(4, 8).cuda() + m2 = torch.rand(8, 16).cuda() + # the index of addmm node in computation graph + node_index = 4 + # strategy number of linear node + strategy_number = 14 + # construct input args + input_args = [input, m1, m2] + # construct meta arg names + meta_arg_names = ['input', 'm1', 'm2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type='bias_module') + + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %m1 : torch.Tensor [#users=1] = placeholder[target=m1] + # %m2 : torch.Tensor [#users=1] = placeholder[target=m2] + # %transpose : [#users=1] = call_function[target=torch.transpose](args = (%m2, 0, 1), kwargs = {}) + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%m1, %transpose), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%input_1, 3), kwargs = {}) + # %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {}) + # return add + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(input_shape).to('meta'), + 'm1': torch.rand(4, 8).to('meta'), + 'm2': torch.rand(8, 16).to('meta'), + }) + gm = ColoGraphModule(model, graph) + # [input_1, m1, m2, addmm, output] + node_list = list(graph.nodes) + linear_node = node_list[4] + strategies_vector = StrategiesVector(linear_node) + + # build handler + handler = LinearFunctionHandler(node=linear_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + assert mapping['input'].name == "m1" + assert mapping['input'].data.shape == torch.Size([4, 8]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 8]) + + assert mapping['other'].name == "transpose" + assert mapping['other'].data.shape == torch.Size([16, 8]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([8, 16]) + + assert mapping['output'].name == "linear" + assert mapping['output'].data.shape == torch.Size([4, 16]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # SS = SR x RS + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR_0' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('m1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('transpose') + output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@parameterize('input_shape', [(16,), (4, 16)]) +@rerun_if_address_is_in_use() +def test_addmm_handler(input_shape): + world_size = 4 + run_func_function = partial(check_linear_function_handler, + input_shape=input_shape, + world_size=world_size, + port=free_port()) + mp.spawn(run_func_function, nprocs=world_size) + + +if __name__ == '__main__': + test_addmm_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..0ab70abffb4cf2e9ef65b860fb61d7ab79c97c52 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py @@ -0,0 +1,119 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import BatchNormModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +def check_bn_module_handler(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.BatchNorm2d(16)).cuda() + + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 16, 64, 64).cuda() + # the index of bn node in computation graph + node_index = 1 + # the total number of bn strategies without sync bn mode + # TODO: add sync bn stategies after related passes ready + strategy_number = 4 + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) + # return _0 + graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 64, 64).to('meta')}) + gm = ColoGraphModule(model, graph) + bn_mod_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(bn_mod_node) + + # build handler + handler = BatchNormModuleHandler(node=bn_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 16, 64, 64]) + + assert mapping['other'].name == "weight" + assert mapping['other'].data.shape == torch.Size([16]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([16]) + + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.shape == torch.Size([16]) + assert mapping['bias'].type == OperationDataType.PARAM + assert mapping['bias'].logical_shape == torch.Size([16]) + + assert mapping['output'].name == "_0" + assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # RS = RS x S + assert 'RS0 = RS0 x S0' in strategy_name_list + assert 'RS1 = RS1 x S1' in strategy_name_list + + # RR = RR x R + assert 'RR = RR x R' in strategy_name_list + + # RS01 = RS01 x S01 + assert 'RS01 = RS01 x S01' in strategy_name_list + + # temporarily skip the sync bn test + # TODO: test sync bn after the implicit runtime pass completed + # SR = SR x R WITH SYNC_BN + # assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list + # assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list + + # SS = SS x S WITH SYNC_BN + # assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list + # assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list + + # S01R = S01R x R WITH SYNC_BN + # assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bn_module_handler(): + world_size = 4 + run_func = partial(check_bn_module_handler, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_bn_module_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py new file mode 100644 index 0000000000000000000000000000000000000000..162d1fbba2952f08cc4951f229e1c8577434c244 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py @@ -0,0 +1,177 @@ +from faulthandler import disable +from functools import partial +from xml.dom import WrongDocumentErr + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.functional as F +from typing_extensions import Self + +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + +WEIGHT_SHAPE = (32, 16) + + +class LinearModule(torch.nn.Module): + + def __init__(self, weight_shape): + super().__init__() + self.weight = torch.nn.Parameter(torch.rand(*weight_shape)) + self.bias = torch.nn.Parameter(torch.rand(weight_shape[0])) + + def forward(self, x): + x = F.linear(x, self.weight, bias=self.bias) + return x + + +def check_linear_module_handler(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = LinearModule(weight_shape=WEIGHT_SHAPE).cuda() + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 4, 4, 16).cuda() + # the index of linear node in computation graph + node_index = 3 + # strategy number of linear node + strategy_number = 24 + # construct input args + input_args = [input] + # construct meta arg names + meta_arg_names = ['x'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type='bias_module') + + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %weight : [#users=1] = get_attr[target=weight] + # %bias : [#users=1] = get_attr[target=bias] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %weight), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %bias), kwargs = {}) + # return add + graph = tracer.trace(model, meta_args={"x": torch.rand(4, 4, 4, 16).to('meta')}) + gm = ColoGraphModule(model, graph) + + linear_mod_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + handler = LinearFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x" + assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([64, 16]) + + assert mapping['other'].name == "weight" + assert mapping['other'].data.shape == torch.Size([32, 16]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([16, 32]) + + assert 'bias' not in mapping + + assert mapping['output'].name == "linear" + assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # SS = SR x RS + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR_0' in strategy_name_list + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('x') + weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') + output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_handler(): + world_size = 4 + run_func_module = partial(check_linear_module_handler, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py new file mode 100644 index 0000000000000000000000000000000000000000..c5c3f378197e4b65ea719f20be6f47abbdf7ffe9 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py @@ -0,0 +1,166 @@ +from faulthandler import disable +from functools import partial +from xml.dom import WrongDocumentErr + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from typing_extensions import Self + +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class LinearModule(torch.nn.Module): + + def __init__(self, in_features, out_features, bias): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x): + x = self.linear(x) + return x + + +def check_linear_module_handler(rank, bias, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = LinearModule(16, 32, bias=bias).cuda() + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 4, 4, 16).cuda() + # the index of linear node in computation graph + node_index = 3 + # strategy number of linear node + strategy_number = 24 + # construct input args + input_args = [input] + # construct meta arg names + meta_arg_names = ['x'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type='bias_module') + + tracer = ColoTracer() + graph = tracer.trace(model, meta_args={"x": torch.rand(4, 4, 4, 16).to('meta')}) + gm = ColoGraphModule(model, graph) + + linear_mod_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + handler = LinearFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x" + assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([64, 16]) + + assert mapping['other'].name == "linear_weight" + assert mapping['other'].data.shape == torch.Size([32, 16]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([16, 32]) + + assert 'bias' not in mapping + + assert mapping['output'].name == "linear" + assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # SS = SR x RS + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR_0' in strategy_name_list + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('x') + weight_sharding_spec = strategy.get_sharding_spec_by_name('linear_weight') + output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_handler(bias=True): + world_size = 4 + run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..cd9f799533bd693e0fde877a1c0cae04f1413f55 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py @@ -0,0 +1,232 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler import BinaryElementwiseHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + class BinaryElementwiseOpModel(nn.Module): + + def __init__(self, op): + super().__init__() + self.op = op + + def forward(self, x1, x2): + out = self.op(x1, x2) + return out + + model = BinaryElementwiseOpModel(op).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + x1 = torch.rand(4, 4).cuda() + x2 = torch.rand([4] * other_dim).cuda() + # the index of binary-elementwise node in computation graph + node_index = 2 + # strategy number of binary-elementwise node + strategy_number = 9 + # construct input args + input_args = [x1, x2] + # construct meta arg names + meta_arg_names = ['x1', 'x2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + + tracer = ColoTracer() + meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + + op_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(op_node) + + # build handler + handler = BinaryElementwiseHandler(node=op_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4]) + + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([4] * other_dim) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 4]) + + assert mapping['output'].name == str(op_node) + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 4]) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size([4, 4]) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # one strategy will be converted to different physical sharding spec + assert len(strategy_name_list) == 9 + + # check if the sharding strategy is correct + assert '[S0, S1] = [S0, S1] [S0, S1]' in strategy_name_list + assert '[S1, S0] = [S1, S0] [S1, S0]' in strategy_name_list + assert '[S01, R] = [S01, R] [S01, R]' in strategy_name_list + assert '[R, S01] = [R, S01] [R, S01]' in strategy_name_list + assert '[S0, R] = [S0, R] [S0, R]' in strategy_name_list + assert '[R, S0] = [R, S0] [R, S0]' in strategy_name_list + assert '[S1, R] = [S1, R] [S1, R]' in strategy_name_list + assert '[R, S1] = [R, S1] [R, S1]' in strategy_name_list + assert '[R, R] = [R, R] [R, R]' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node)) + + # make sure the sharding spec is the same for input and output + assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence + + # since the dim of the other can change, we make sure at least its last dim sharding is the same + if len(other_sharding_spec.sharding_sequence) == 2: + assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence + elif len(other_sharding_spec.sharding_sequence) == 1: + assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1] + + +def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + class BinaryElementwiseOpModel(nn.Module): + + def __init__(self, op, const): + super().__init__() + self.op = op + self.const = const + + def forward(self, x1): + out = self.op(x1, self.const) + return out + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + model = BinaryElementwiseOpModel(op, other_dim).cuda() + x1 = torch.rand(4, 4).cuda() + # the index of binary-elementwise node in computation graph + node_index = 1 + # strategy number of binary-elementwise node + strategy_number = 9 + # construct input args + input_args = [x1] + # construct meta arg names + meta_arg_names = ['x1'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + tracer = ColoTracer() + meta_args = {'x1': torch.rand(4, 4).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + + op_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(op_node) + + # build handler + handler = BinaryElementwiseHandler(node=op_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4]) + + assert mapping['output'].name == str(op_node) + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 4]) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size([4, 4]) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # one strategy will be converted to different physical sharding spec + assert len(strategy_name_list) == 9 + + # check if the sharding strategy is correct + assert '[S0, S1] = [S0, S1] [S0, S1]' in strategy_name_list + assert '[S1, S0] = [S1, S0] [S1, S0]' in strategy_name_list + assert '[S01, R] = [S01, R] [S01, R]' in strategy_name_list + assert '[R, S01] = [R, S01] [R, S01]' in strategy_name_list + assert '[S0, R] = [S0, R] [S0, R]' in strategy_name_list + assert '[R, S0] = [R, S0] [R, S0]' in strategy_name_list + assert '[S1, R] = [S1, R] [S1, R]' in strategy_name_list + assert '[R, S1] = [R, S1] [R, S1]' in strategy_name_list + assert '[R, R] = [R, R] [R, R]' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node)) + + # make sure the sharding spec is the same for input and output + assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence + + +@parameterize('op', [torch.add]) +@parameterize('other_dim', [1, 2]) +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_binary_elementwise_handler(op, other_dim): + world_size = 4 + run_func_tensor = partial(check_binary_elementwise_handler_with_tensor, + op=op, + other_dim=other_dim, + world_size=world_size, + port=free_port()) + mp.spawn(run_func_tensor, nprocs=world_size) + run_func_int = partial(check_binary_elementwise_handler_with_int, + op=op, + other_dim=other_dim, + world_size=world_size, + port=free_port()) + mp.spawn(run_func_int, nprocs=world_size) + + +if __name__ == '__main__': + test_binary_elementwise_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..778469df404d2802758b69fce6920e8935dd4c24 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py @@ -0,0 +1,219 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class BMMTensorMethodModule(nn.Module): + + def forward(self, x1, x2): + return x1.bmm(x2) + + +class BMMTorchFunctionModule(nn.Module): + + def forward(self, x1, x2): + return torch.bmm(x1, x2) + + +def check_2d_device_mesh(rank, module, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = module().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + x1 = torch.rand(4, 8, 16).cuda() + x2 = torch.rand(4, 16, 8).cuda() + # the index of bmm node in computation graph + node_index = 2 + # strategy number of bmm node on 2d device mesh + strategy_number = 7 + # construct input args + input_args = [x1, x2] + # construct meta arg names + meta_arg_names = ['x1', 'x2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + tracer = ColoTracer() + graph = tracer.trace(model, + meta_args={ + "x1": torch.rand(4, 8, 16).to('meta'), + 'x2': torch.rand(4, 16, 8).to('meta') + }) + gm = ColoGraphModule(model, graph) + + linear_mod_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + handler = BMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 8, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([4, 16, 8]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + + assert mapping['output'].name == "bmm" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 8, 8]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # one batch dim + assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list + + # two batch dim + assert 'Sb01 = Sb01 x Sb01' in strategy_name_list + + # SbSi = SbSi x Sb + assert 'Sb0Si1 = Sb0Si1 x Sb0' in strategy_name_list + assert 'Sb1Si0 = Sb1Si0 x Sb1' in strategy_name_list + + # SbSj = SbR x SbSj + assert 'Sb0Sj1 = Sb0R x Sb0Sj1' in strategy_name_list + assert 'Sb1Sj0 = Sb1R x Sb1Sj0' in strategy_name_list + + # SbR = SbSk x SbSk + assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list + assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + +def check_1d_device_mesh(rank, module, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = module().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (1, 4) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + x1 = torch.rand(4, 8, 16).cuda() + x2 = torch.rand(4, 16, 8).cuda() + # the index of bmm node in computation graph + node_index = 2 + # strategy number of bmm node on 1d device mesh + strategy_number = 1 + # construct input args + input_args = [x1, x2] + # construct meta arg names + meta_arg_names = ['x1', 'x2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + tracer = ColoTracer() + graph = tracer.trace(model, + meta_args={ + "x1": torch.rand(4, 8, 16).to('meta'), + 'x2': torch.rand(4, 16, 8).to('meta') + }) + gm = ColoGraphModule(model, graph) + linear_mod_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + handler = BMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 8, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([4, 16, 8]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + + assert mapping['output'].name == "bmm" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 8, 8]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + assert len(strategy_name_list) == 1 + # one batch dim + assert 'Sb0 = Sb0 x Sb0' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + +@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bmm_handler(module): + world_size = 4 + run_func_2d = partial(check_2d_device_mesh, module=module, world_size=world_size, port=free_port()) + mp.spawn(run_func_2d, nprocs=world_size) + run_func_1d = partial(check_1d_device_mesh, module=module, world_size=world_size, port=free_port()) + mp.spawn(run_func_1d, nprocs=world_size) + + +if __name__ == '__main__': + test_bmm_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..2acd015c8f5907d63c327e4f5ace482a3dfd6518 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py @@ -0,0 +1,319 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +def check_conv_module_handler(rank, bias, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) + # return _0 + input = torch.rand(4, 4, 64, 64).cuda() + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of conv node in computation graph + node_index = 1 + # total number of conv strategies + strategy_number = 16 + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + tracer = ColoTracer() + graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')}) + gm = ColoGraphModule(model, graph) + conv_mod_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(conv_mod_node) + + # build handler + handler = ConvModuleHandler(node=conv_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + # assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping['other'].name == "weight" + # assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3]) + + if bias: + assert mapping['bias'].name == "bias" + # assert mapping['bias'].data.is_meta + assert mapping['bias'].data.shape == torch.Size([16]) + assert mapping['bias'].type == OperationDataType.PARAM + assert mapping['bias'].logical_shape == torch.Size([16]) + + assert mapping['output'].name == "_0" + # assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # SS = SR x RS + assert 'S0S1 = S0R x RS1' in strategy_name_list + assert 'S1S0 = S1R x RS0' in strategy_name_list + + # SR = SR x RR + assert 'S0R = S0R x RR' in strategy_name_list + assert 'S1R = S1R x RR' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R' in strategy_name_list + assert 'S1R = S1S0 x S0R' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') + output_sharding_spec = strategy.get_sharding_spec_by_name('_0') + + if bias: + bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + + # make sure the sharding matches across different operation data + assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0] + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] + assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:] + assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1] + + if bias: + assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0] + assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1] + + +class ConvModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, others, bias=None): + x = nn.functional.conv2d(input, others, bias=bias, padding=1) + return x + + +def check_conv_function_handler(rank, bias, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = ConvModel().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 4, 64, 64).cuda() + others = torch.rand(16, 4, 3, 3).cuda() + input_args = [input, others] + meta_arg_names = ['input', 'others'] + input_kwargs = {} + # total number of conv strategies + strategy_number = 16 + node_index = 2 + if bias: + bias_tensor = torch.rand(16).cuda() + input_kwargs['bias'] = bias_tensor + node_index += 1 + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + input_kwargs=input_kwargs) + + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %others : torch.Tensor [#users=1] = placeholder[target=others] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %others), kwargs = {}) + # return conv2d + meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta'), "others": torch.rand(16, 4, 3, 3).to('meta')} + if bias: + meta_args['bias'] = torch.rand(16).to('meta') + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + + if bias: + conv_mod_node = list(graph.nodes)[3] + else: + conv_mod_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(conv_mod_node) + + # build handler + handler = ConvFunctionHandler(node=conv_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping['other'].name == "others" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3]) + + if bias: + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.is_meta + assert mapping['bias'].data.shape == torch.Size([16]) + assert mapping['bias'].type == OperationDataType.ARG + assert mapping['bias'].logical_shape == torch.Size([16]) + + assert mapping['output'].name == "conv2d" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # SS = SR x RS + assert 'S0S1 = S0R x RS1' in strategy_name_list + assert 'S1S0 = S1R x RS0' in strategy_name_list + + # SR = SR x RR + assert 'S0R = S0R x RR' in strategy_name_list + assert 'S1R = S1R x RR' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R' in strategy_name_list + assert 'S1R = S1S0 x S0R' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('others') + output_sharding_spec = strategy.get_sharding_spec_by_name('conv2d') + + if bias: + bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + + # make sure the sharding matches across different operation data + assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0] + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] + assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:] + assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1] + + if bias: + assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0] + assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +# We temporarily ban the bias option before doing bias add +# before all reduce communication may encounter correctness issue. +# @parameterize('bias', [True, False]) +@rerun_if_address_is_in_use() +def test_conv_module_handler(bias=False): + world_size = 4 + run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +# We temporarily ban the bias option before doing bias add +# before all reduce communication may encounter correctness issue. +# @parameterize('bias', [True, False]) +@rerun_if_address_is_in_use() +def test_conv_function_handler(bias=False): + world_size = 4 + run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_conv_module_handler() + test_conv_function_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..5bce383dd0ab33c9179682cab5d824fce2c12c56 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py @@ -0,0 +1,286 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.embedding_handler import ( + EmbeddingFunctionHandler, + EmbeddingModuleHandler, +) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + +NUM_EMBEDDINGS = 16 +EMBEDDING_DIMS = 32 + + +class EmbeddingModule(nn.Module): + + def __init__(self, num_embeddings, embedding_dims): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dims) + + def forward(self, input): + x = self.embedding(input) + return x + + +def check_embedding_module_handler(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = EmbeddingModule(num_embeddings=NUM_EMBEDDINGS, embedding_dims=EMBEDDING_DIMS).cuda() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %embedding : [#users=1] = call_module[target=embedding](args = (%input_1,), kwargs = {}) + # return embedding + input = torch.rand(4, 16, 16) * NUM_EMBEDDINGS + input = input.to(torch.int64).cuda() + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of embedding node in computation graph + node_index = 1 + # total number of embedding strategies + strategy_number = 19 + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + tracer = ColoTracer() + graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 16).to('meta')}) + gm = ColoGraphModule(model, graph) + embedding_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(embedding_node) + + # build handler + handler = EmbeddingModuleHandler(node=embedding_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + # assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 16, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([1024]) + + assert mapping['other'].name == "weight" + assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + + assert mapping['output'].name == "embedding" + assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # RR = RR x RR + assert 'RR = R x RR' in strategy_name_list + + # SR = SR x RR + assert 'S0R = S0 x RR_0' in strategy_name_list + assert 'S0R = S0 x RR_1' in strategy_name_list + assert 'S0R = S0 x RR_2' in strategy_name_list + assert 'S1R = S1 x RR_0' in strategy_name_list + assert 'S1R = S1 x RR_1' in strategy_name_list + assert 'S1R = S1 x RR_2' in strategy_name_list + + # SS = SR x RS + assert 'S0S1 = S0 x RS1_0' in strategy_name_list + assert 'S0S1 = S0 x RS1_1' in strategy_name_list + assert 'S0S1 = S0 x RS1_2' in strategy_name_list + assert 'S1S0 = S1 x RS0_0' in strategy_name_list + assert 'S1S0 = S1 x RS0_1' in strategy_name_list + assert 'S1S0 = S1 x RS0_2' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = R x RS0' in strategy_name_list + assert 'RS1 = R x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01 x RR_0' in strategy_name_list + assert 'S01R = S01 x RR_1' in strategy_name_list + assert 'S01R = S01 x RR_2' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = R x RS01' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') + output_sharding_spec = strategy.get_sharding_spec_by_name('embedding') + + # make sure the sharding matches across different operation data + assert output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1] + assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence[:-1] + + +class EmbeddingFunction(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, others): + x = nn.functional.embedding(input, others) + return x + + +def check_embedding_function_handler(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = EmbeddingFunction().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 16, 16) * NUM_EMBEDDINGS + input = input.to(torch.int64).cuda() + others = torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).cuda() + input_args = [input, others] + meta_arg_names = ['input', 'others'] + input_kwargs = {} + # total number of embedding strategies + strategy_number = 19 + node_index = 2 + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + input_kwargs=input_kwargs) + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %others : torch.Tensor [#users=1] = placeholder[target=others] + # %embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_1, %others), kwargs = {padding_idx: None, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False}) + # return embedding + meta_args = { + "input": torch.rand(4, 16, 16).to('meta'), + "others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to('meta') + } + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + + embedding_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(embedding_node) + + # build handler + handler = EmbeddingFunctionHandler(node=embedding_node, + device_mesh=device_mesh, + strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 16, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([1024]) + + assert mapping['other'].name == "others" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + + assert mapping['output'].name == "embedding" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) + + handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # RR = RR x RR + assert 'RR = R x RR' in strategy_name_list + + # SR = SR x RR + assert 'S0R = S0 x RR_0' in strategy_name_list + assert 'S0R = S0 x RR_1' in strategy_name_list + assert 'S0R = S0 x RR_2' in strategy_name_list + assert 'S1R = S1 x RR_0' in strategy_name_list + assert 'S1R = S1 x RR_1' in strategy_name_list + assert 'S1R = S1 x RR_2' in strategy_name_list + + # SS = SR x RS + assert 'S0S1 = S0 x RS1_0' in strategy_name_list + assert 'S0S1 = S0 x RS1_1' in strategy_name_list + assert 'S0S1 = S0 x RS1_2' in strategy_name_list + assert 'S1S0 = S1 x RS0_0' in strategy_name_list + assert 'S1S0 = S1 x RS0_1' in strategy_name_list + assert 'S1S0 = S1 x RS0_2' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = R x RS0' in strategy_name_list + assert 'RS1 = R x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01 x RR_0' in strategy_name_list + assert 'S01R = S01 x RR_1' in strategy_name_list + assert 'S01R = S01 x RR_2' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = R x RS01' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('others') + output_sharding_spec = strategy.get_sharding_spec_by_name('embedding') + + # make sure the sharding matches across different operation data + assert output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1] + assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence[:-1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_embedding_module_handler(): + world_size = 4 + run_func = partial(check_embedding_module_handler, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_embedding_function_handler(): + world_size = 4 + run_func = partial(check_embedding_function_handler, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_embedding_module_handler() + test_embedding_function_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ad093c2edf43481eade187b9bef1ac781d0d41dc --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer + + +class GetattrModel(nn.Module): + + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(4, 16, 3, padding=1, bias=False) + + def forward(self, input): + weight = self.conv.weight + return weight + + +def test_getattr_handler(): + model = GetattrModel() + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=0] = placeholder[target=input] + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # return conv_weight + graph = tracer.trace(model, meta_args={'input': torch.rand(4, 4, 64, 64).to('meta')}) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + getattr_node = list(graph.nodes)[1] + getattr_strategies_vector = StrategiesVector(getattr_node) + + # build handler + getattr_handler = GetattrHandler(node=getattr_node, + device_mesh=device_mesh, + strategies_vector=getattr_strategies_vector) + + getattr_handler.register_strategy(compute_resharding_cost=False) + # check operation data mapping + mapping = getattr_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['output'].name == "conv_weight" + assert mapping['output'].data.shape == torch.Size((16, 4, 3, 3)) + assert mapping['output'].type == OperationDataType.OUTPUT + strategy_name_list = [val.name for val in getattr_handler.strategies_vector] + assert "Replica Attribute" in strategy_name_list + + +if __name__ == '__main__': + test_getattr_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..4e01ed2437774ed97567a07819728d51dfcf7c8b --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py @@ -0,0 +1,167 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler +from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler +from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx.tracer.meta_patch.patched_module import linear +from colossalai.testing.pytest_wrapper import run_on_environment_flag + + +class GetItemFromTensorModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other) + x = conv_node[1] + return x + + +def test_getitem_from_tensor_handler(): + model = GetItemFromTensorModel() + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) + # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%conv2d, 1), kwargs = {}) + # return getitem + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(4, 4, 64, 64).to('meta'), + "other": torch.rand(4, 16, 3, 3).to('meta'), + }) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + conv_mod_node = list(graph.nodes)[2] + getitem_mod_node = list(graph.nodes)[3] + getitem_strategies_vector = StrategiesVector(getitem_mod_node) + conv_strategies_vector = StrategiesVector(conv_mod_node) + + # build handler + conv_handler = ConvFunctionHandler(node=conv_mod_node, + device_mesh=device_mesh, + strategies_vector=conv_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) + getitem_handler = GetItemHandler(node=getitem_mod_node, + device_mesh=device_mesh, + strategies_vector=getitem_strategies_vector) + + getitem_handler.register_strategy(compute_resharding_cost=False) + # check operation data mapping + mapping = getitem_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "conv2d" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62]) + + assert mapping['index'].name == "index" + assert isinstance(mapping['index'].data, int) + assert mapping['index'].type == OperationDataType.ARG + + assert mapping['output'].name == "getitem" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 62, 62]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(getitem_strategies_vector) == len(conv_strategies_vector) + + +class GetItemFromTupleModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input): + split_node = torch.split(input, 2, 0) + x = split_node[1] + return x + + +def test_getitem_from_tuple_handler(): + model = GetItemFromTupleModel() + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %split : [#users=1] = call_function[target=torch.functional.split](args = (%conv2d, 2), kwargs = {dim: 0}) + # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {}) + # return getitem + graph = tracer.trace(model, meta_args={ + "input": torch.rand(4, 4, 64, 64).to('meta'), + }) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + input_node = list(graph.nodes)[0] + split_node = list(graph.nodes)[1] + getitem_node = list(graph.nodes)[2] + input_strategies_vector = StrategiesVector(input_node) + getitem_strategies_vector = StrategiesVector(getitem_node) + split_strategies_vector = StrategiesVector(split_node) + + # build handler + input_handler = PlacehodlerHandler( + node=input_node, + device_mesh=device_mesh, + strategies_vector=input_strategies_vector, + placeholder_option='replicated', + ) + input_handler.register_strategy(compute_resharding_cost=False) + setattr(input_node, 'strategies_vector', input_strategies_vector) + split_handler = ReshapeHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector) + split_handler.register_strategy(compute_resharding_cost=False) + setattr(split_node, 'strategies_vector', split_strategies_vector) + getitem_handler = GetItemHandler(node=getitem_node, + device_mesh=device_mesh, + strategies_vector=getitem_strategies_vector) + getitem_handler.register_strategy(compute_resharding_cost=False) + setattr(getitem_node, 'strategies_vector', getitem_strategies_vector) + + # check operation data mapping + mapping = getitem_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "split" + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == (torch.Size([2, 4, 64, 64]), torch.Size([2, 4, 64, 64])) + + assert mapping['index'].name == "index" + assert isinstance(mapping['index'].data, int) + assert mapping['index'].type == OperationDataType.ARG + + assert mapping['output'].name == "getitem" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([2, 4, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(getitem_strategies_vector) == len(split_strategies_vector) + + +if __name__ == '__main__': + test_getitem_from_tensor_handler() + test_getitem_from_tuple_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..f4d0063fd6b69717a0621160577fbc44da32d50d --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py @@ -0,0 +1,109 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx.tracer.meta_patch.patched_module import linear +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +def check_ln_module_handler(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.LayerNorm(16)).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 16).cuda() + # the index of bn node in computation graph + node_index = 1 + # the total number of ln strategies + strategy_number = 4 + # construct input args + input_args = [input] + # construct meta arg names + meta_arg_names = ['input'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) + # return _0 + graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')}) + gm = ColoGraphModule(model, graph) + + ln_mod_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(ln_mod_node) + + # build handler + handler = LayerNormModuleHandler(node=ln_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.shape == torch.Size([4, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 16]) + + assert mapping['other'].name == "weight" + assert mapping['other'].data.shape == torch.Size([16]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([16]) + + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.shape == torch.Size([16]) + assert mapping['bias'].type == OperationDataType.PARAM + assert mapping['bias'].logical_shape == torch.Size([16]) + + assert mapping['output'].name == "_0" + assert mapping['output'].data.shape == torch.Size([4, 16]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # SR = SR x R + assert '[S0, R] = [S0, R] x [R]' in strategy_name_list + assert '[S1, R] = [S1, R] x [R]' in strategy_name_list + + # RR = RR x R + assert 'RR = RR x R' in strategy_name_list + + # S01R = S01R x R + assert '[S01, R] = [S01, R] x [R]' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_ln_module_handler(): + world_size = 4 + run_func = partial(check_ln_module_handler, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_ln_module_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..fb8821fae58bb30d610377716e52625ed669d2ee --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -0,0 +1,332 @@ +from faulthandler import disable +from functools import partial +from xml.dom import WrongDocumentErr + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from typing_extensions import Self + +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +def check_linear_module_handler(rank, bias, input_shape, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(input_shape).cuda() + # the index of linear node in computation graph + node_index = 1 + # strategy number of linear node + if input_shape == (1, 4, 4, 16): + strategy_number = 19 + else: + strategy_number = 24 + # construct input args + input_args = [input] + # construct meta arg names + meta_arg_names = ['input'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + + tracer = ColoTracer() + graph = tracer.trace(model, meta_args={"input": torch.rand(input_shape).to('meta')}) + gm = ColoGraphModule(model, graph) + + linear_mod_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + handler = LinearModuleHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.shape == torch.Size(input_shape) + assert mapping['input'].type == OperationDataType.ARG + input_logical_shape = mapping['input'].data.view(-1, 16).shape + assert mapping['input'].logical_shape == input_logical_shape + + assert mapping['other'].name == "weight" + assert mapping['other'].data.shape == torch.Size([32, 16]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([16, 32]) + + if bias: + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.shape == torch.Size([32]) + assert mapping['bias'].type == OperationDataType.PARAM + assert mapping['bias'].logical_shape == torch.Size([32]) + + assert mapping['output'].name == "_0" + output_shape = input_shape[:-1] + (32,) + assert mapping['output'].data.shape == torch.Size(output_shape) + assert mapping['output'].type == OperationDataType.OUTPUT + output_logical_shape = mapping['output'].data.view(-1, 32).shape + assert mapping['output'].logical_shape == torch.Size(output_logical_shape) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # First dimension cannot be shard if input shape is (1, 4, 4, 16) + if input_shape != (1, 4, 4, 16): + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S01R = S01R x RR_0' in strategy_name_list + + # SS = SR x RS + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list + assert 'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') + output_sharding_spec = strategy.get_sharding_spec_by_name('_0') + + if bias: + bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] + + if bias: + assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + +class LinearModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, others, bias=None): + x = nn.functional.linear(input, others, bias=bias) + return x + + +def check_linear_function_handler(rank, bias, input_shape, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = LinearModel().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + input = torch.rand(input_shape).cuda() + other = torch.rand(32, 16).cuda() + # the index of linear node in computation graph + node_index = 2 + # strategy number of linear node + if input_shape == (1, 4, 4, 16): + strategy_number = 19 + else: + strategy_number = 24 + # construct input args + input_args = [input, other] + # construct meta arg names + meta_arg_names = ['input', 'others'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + + tracer = ColoTracer() + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(input_shape).to('meta'), + 'others': torch.rand(32, 16).to('meta') + }) + gm = ColoGraphModule(model, graph) + if bias: + linear_func_node = list(graph.nodes)[3] + else: + linear_func_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(linear_func_node) + + # build handler + handler = LinearFunctionHandler(node=linear_func_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # # check operation data mapping + mapping = handler.get_operation_data_mapping() + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.shape == torch.Size(input_shape) + assert mapping['input'].type == OperationDataType.ARG + input_logical_shape = mapping['input'].data.view(-1, 16).shape + assert mapping['input'].logical_shape == torch.Size(input_logical_shape) + + assert mapping['other'].name == "others" + assert mapping['other'].data.shape == torch.Size([32, 16]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([16, 32]) + + if bias: + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.shape == torch.Size([32]) + assert mapping['bias'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([16, 32]) + + assert mapping['output'].name == "linear" + output_shape = input_shape[:-1] + (32,) + assert mapping['output'].data.shape == torch.Size(output_shape) + assert mapping['output'].type == OperationDataType.OUTPUT + output_logical_shape = mapping['output'].data.view(-1, 32).shape + assert mapping['output'].logical_shape == torch.Size(output_logical_shape) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # First dimension cannot be shard if input shape is (1, 4, 4, 16) + if input_shape != (1, 4, 4, 16): + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S01R = S01R x RR_0' in strategy_name_list + + # SS = SR x RS + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list + assert 'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('others') + output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + + if bias: + bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] + + if bias: + assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + +@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)]) +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_handler(input_shape, bias=False): + world_size = 4 + run_func_module = partial(check_linear_module_handler, + bias=bias, + input_shape=input_shape, + world_size=world_size, + port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + run_func_function = partial(check_linear_function_handler, + bias=bias, + input_shape=input_shape, + world_size=world_size, + port=free_port()) + mp.spawn(run_func_function, nprocs=world_size) + + +if __name__ == '__main__': + test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..306c45f56dbfd1394f0c8ff81feeab4941191294 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import ( + MatMulHandler, + MatMulType, + _get_bmm_logical_shape, + get_matmul_type, +) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing.utils import parameterize + + +class MatMulModule(nn.Module): + + def forward(self, x1, x2): + return torch.matmul(x1, x2) + + +@parameterize( + 'tensor_shapes', + [ + [[8], [8]], # dot product + [[4, 8], [8]], # mat-vec product + [[4, 8], [8, 16]], # mat-mat product + [[8], [8, 16]], # mat-mat product + [[8], [4, 8, 16]], # batched mat-mat product with padding + broadcasting + [[4, 8, 16], [16]], # batched mat-mat product with padding + broadcasting + [[4, 8, 16], [16, 32]], # batched mat-mat product with broadcasting + [[4, 8, 16], [1, 16, 32]], # batched mat-mat product with broadcasting + [[8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[1, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[2, 1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[2, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product without broadcasting + ]) +def test_matmul_node_handler(tensor_shapes): + input_shape, other_shape = tensor_shapes + + # get output shape + x1 = torch.rand(*input_shape) + x2 = torch.rand(*other_shape) + output_shape = list(torch.matmul(x1, x2).shape) + + # get matmul type + matmul_type = get_matmul_type(x1.dim(), x2.dim()) + + model = MatMulModule() + + tracer = ColoTracer() + graph = tracer.trace(model, meta_args={"x1": x1.to('meta'), 'x2': x2.to('meta')}) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + print(graph) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + mod_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(mod_node) + + # build handler + handler = MatMulHandler(node=mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + logical_input_shape = input_shape + logical_other_shape = other_shape + logical_output_shape = output_shape + if matmul_type == MatMulType.MM and len(input_shape) == 1: + logical_input_shape = [1] + input_shape + elif matmul_type == MatMulType.BMM: + logical_input_shape, logical_other_shape, logical_output_shape = _get_bmm_logical_shape( + input_shape, other_shape, handler.transforms) + else: + logical_input_shape = input_shape + + # check input operation data + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size(input_shape) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size(logical_input_shape) + + # check other operation data + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size(other_shape) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size(logical_other_shape) + + # check output + assert mapping['output'].name == "matmul" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size(output_shape) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size(logical_output_shape) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # ensure there is no duplicate strategy + if matmul_type != MatMulType.BMM: + assert len(set(strategy_name_list)) == len(strategy_name_list), strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name('matmul') + + if matmul_type == MatMulType.DOT: + # dot product will produce a scaler + # results should fulfill: + # 1. the input and other operands have the same sharding spec + # 2. the output has no sharding + assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence + assert len(output_sharding_spec.sharding_sequence) == 0 + elif matmul_type == MatMulType.MV: + # matrix-vector product should fulfill + # 1. the last dim of the input and other operands should have the same sharding + # 2. the first dim of the input and other should have the same sharding + # 3. the output should have only 1 dim + assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1] + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] + assert len(output_sharding_spec.sharding_sequence) == 1 + elif matmul_type == MatMulType.MM: + # matrix-matrix multiplication should fulfil + # 1. if input is a 2D tensor, the 1st dim of input and output should have the same sharding + # 2. the input's last dim and the first dim of the other should have the same sharding + # 3. the last dim of the output and other should have the same sharding + # 4. the input and output should have the same number of dims + if len(input_shape) == 2: + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] + assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[0] + assert output_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1] + assert len(input_sharding_spec.sharding_sequence) == len(output_sharding_spec.sharding_sequence) + elif matmul_type == MatMulType.BMM: + # bmm should fulfil + # 1. of the other tensor is not a 1d tensor, the last dim of other and output have the same sharding + # 2. if the input has more than 2 dim, the second last dim of input and output have the same sharding + # 3. if the other have more than 2 dim, the second last dim of other and the last dim of input should have the same sharding + if len(other_shape) > 1: + assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + if len(input_shape) > 1: + assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2] + if len(other_shape) > 2: + assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1] + + +if __name__ == '__main__': + test_matmul_node_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..d47876af2a2d93f39e3dc69d0e6d7ca0faf6b630 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py @@ -0,0 +1,58 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import \ + NormPoolingHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx.tracer.meta_patch.patched_module import linear +from colossalai.testing.pytest_wrapper import run_on_environment_flag + + +def test_norm_pool_handler(): + model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta')) + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) + # return _0 + graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')}) + + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + conv_mod_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(conv_mod_node) + + # build handler + handler = NormPoolingHandler(node=conv_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping['output'].name == "_0" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 4, 16, 16]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + assert len(strategy_name_list) == 9 + + +if __name__ == '__main__': + test_norm_pool_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..16eb983000191fae8d4915c2865802e2e499fe51 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OuputHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use + + +class OutputModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + y = x * 2 + return x, y + + +@parameterize('output_option', ['distributed', 'replicated']) +@rerun_if_address_is_in_use() +def test_output_handler(output_option): + model = OutputModel() + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=2] = placeholder[target=x] + # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) + # return (x, mul) + graph = tracer.trace(model, meta_args={ + "x": torch.rand(4, 4, 64, 64).to('meta'), + }) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + output_node = list(graph.nodes)[2] + output_strategies_vector = StrategiesVector(output_node) + + # build handler + otuput_handler = OuputHandler(node=output_node, + device_mesh=device_mesh, + strategies_vector=output_strategies_vector, + output_option=output_option) + + otuput_handler.register_strategy(compute_resharding_cost=False) + # check operation data mapping + mapping = otuput_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['output'].name == "output" + assert mapping['output'].type == OperationDataType.OUTPUT + strategy_name_list = [val.name for val in otuput_handler.strategies_vector] + if output_option == 'distributed': + assert "Distributed Output" in strategy_name_list + else: + assert "Replica Output" in strategy_name_list + + +if __name__ == '__main__': + test_output_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..c695b8843a3c2a82110a20a1007fffb558bd2d21 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py @@ -0,0 +1,339 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.experimental import PermuteHandler, TransposeHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class ConvReshapeModel(nn.Module): + + def __init__(self, reshape_dims, call_function): + super().__init__() + self.reshape_dims = reshape_dims + self.call_function = call_function + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other, bias=None) + # permute_node = torch.permute(conv_node, self.permute_dims) + if self.call_function == torch.permute: + permute_node = self.call_function(conv_node, self.reshape_dims) + else: + permute_node = self.call_function(conv_node, *self.reshape_dims) + return permute_node + + +class LinearReshapeModel(nn.Module): + + def __init__(self, reshape_dims, call_function): + super().__init__() + self.reshape_dims = reshape_dims + self.call_function = call_function + + def forward(self, input, other): + linear_node = nn.functional.linear(input, other, bias=None) + # permute_node = torch.permute(linear_node, self.tgt_shape) + if self.call_function == torch.permute: + permute_node = self.call_function(linear_node, self.reshape_dims) + else: + permute_node = self.call_function(linear_node, *self.reshape_dims) + return permute_node + + +def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + if call_function == torch.permute: + reshape_dims = reshape_dims[0] + elif call_function == torch.transpose: + reshape_dims = reshape_dims[1] + model = model_cls(reshape_dims, call_function).cuda() + + if model_cls.__name__ == 'ConvReshapeModel': + input = torch.rand(8, 8, 66, 66).to('cuda') + other = torch.rand(16, 8, 3, 3).to('cuda') + # index of conv node in computation graph + node_index = 2 + # total number of conv strategies + strategy_number = 16 + if model_cls.__name__ == 'LinearReshapeModel': + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer() + if model_cls.__name__ == 'ConvReshapeModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {bias: None}) + # %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {}) + # return permute + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 8, 66, 66).to('meta'), + "other": torch.rand(16, 8, 3, 3).to('meta'), + }) + + if model_cls.__name__ == 'LinearReshapeModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) + # return permute + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), + }) + + gm = ColoGraphModule(model, graph) + + previous_mod_node = list(graph.nodes)[2] + reshape_node = list(graph.nodes)[3] + view_strategies_vector = StrategiesVector(reshape_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + if model_cls.__name__ == 'ConvReshapeModel': + + conv_handler = ConvFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + if model_cls.__name__ == 'LinearReshapeModel': + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + if call_function == torch.permute: + reshape_handler = PermuteHandler(node=reshape_node, + device_mesh=device_mesh, + strategies_vector=view_strategies_vector) + else: + reshape_handler = TransposeHandler(node=reshape_node, + device_mesh=device_mesh, + strategies_vector=view_strategies_vector) + + reshape_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = reshape_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + if model_cls.__name__ == 'ConvReshapeModel': + assert mapping['input'].name == "conv2d" + else: + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + if call_function == torch.permute: + assert mapping['output'].name == "permute" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.permute(torch.rand(8, 16, 64, 64), reshape_dims).shape + assert mapping['output'].type == OperationDataType.OUTPUT + else: + assert mapping['output'].name == "transpose" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.transpose(torch.rand(8, 16, 64, 64), *reshape_dims).shape + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(view_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in view_strategies_vector] + if rank == 0: + for name in strategy_name_list: + print(name) + if model_cls.__name__ == 'ConvReshapeModel': + + if reshape_dims in ((0, 2, 1, 3), (1, 2)): + assert '[S0, S1, R, R] -> [S0, R, S1, R]_0' in strategy_name_list + assert '[S1, S0, R, R] -> [S1, R, S0, R]_1' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list + + if reshape_dims == (2, 0, 1, 3): + assert '[S0, S1, R, R] -> [R, S0, S1, R]_0' in strategy_name_list + assert '[S1, S0, R, R] -> [R, S1, S0, R]_1' in strategy_name_list + assert '[S0, R, R, R] -> [R, S0, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [R, S1, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [R, S0, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [R, S1, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [R, S01, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list + + if reshape_dims == (1, 3): + assert '[S0, S1, R, R] -> [S0, R, R, S1]_0' in strategy_name_list + assert '[S1, S0, R, R] -> [S1, R, R, S0]_1' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, R, S1]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, R, S0]_10' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, R, S1]_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, R, S01]_15' in strategy_name_list + + if model_cls.__name__ == 'LinearReshapeModel': + + if reshape_dims == ((0, 2, 1, 3), (1, 2)): + assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, S0, R, S1]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, S1, R, S0]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, S0, R, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, S1, R, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, S01, R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + + if reshape_dims == (2, 0, 1, 3): + assert '[S0, R, R, S1] -> [R, S0, R, S1]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [S0, R, R, S1]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [R, S1, R, S0]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [S1, R, R, S0]_5' in strategy_name_list + assert '[S0, R, R, R] -> [R, S0, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [S0, R, R, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [R, S1, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [S1, R, R, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[S01, R, R, R] -> [R, S01, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [S01, R, R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + + if reshape_dims == (1, 3): + assert '[S0, R, R, S1] -> [S0, S1, R, R]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S1, R, S0]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, S1, S0, R]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, S0, R, R]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S0, R, S1]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, S0, S1, R]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1, R, R]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0, R, R]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1, R, R]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, R, S01]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, S01, R, R]_22' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('call_function', [torch.permute, torch.transpose]) +@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))]) +@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel]) +def test_view_handler(call_function, reshape_dims, model_cls): + world_size = 4 + run_func = partial(check_view_handler, + call_function=call_function, + reshape_dims=reshape_dims, + model_cls=model_cls, + world_size=world_size, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_view_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..0aafb9e0b1c86ffb96b5c713e77a6b4d7ec385f3 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use + + +class PlaceholderModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input): + return input + + +@parameterize('placeholder_option', ['distributed', 'replicated']) +@rerun_if_address_is_in_use() +def test_placeholder_handler(placeholder_option): + model = PlaceholderModel() + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # return input_1 + graph = tracer.trace(model, meta_args={ + "input": torch.rand(4, 4, 64, 64).to('meta'), + }) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + placeholder_node = list(graph.nodes)[0] + placeholder_strategies_vector = StrategiesVector(placeholder_node) + # build handler + placeholder_handler = PlacehodlerHandler(node=placeholder_node, + device_mesh=device_mesh, + strategies_vector=placeholder_strategies_vector, + placeholder_option=placeholder_option) + + placeholder_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = placeholder_handler.get_operation_data_mapping() + + strategy = placeholder_strategies_vector[0] + strategy_sharding_spec = strategy.get_sharding_spec_by_name(mapping['output'].name) + + if placeholder_option == 'distributed': + assert str(strategy_sharding_spec.sharding_sequence) == '[S01, R, R, R]' + else: + assert str(strategy_sharding_spec.sharding_sequence) == '[R, R, R, R]' + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['output'].name == "input_1" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size((4, 4, 64, 64)) + assert mapping['output'].type == OperationDataType.OUTPUT + strategy_name_list = [val.name for val in placeholder_handler.strategies_vector] + if placeholder_option == 'replicated': + assert "Replica Placeholder" in strategy_name_list + else: + assert "Distributed Placeholder" in strategy_name_list + + +if __name__ == '__main__': + test_placeholder_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..613f8f3d0ae4ff0064df28b22570fa777f59cd1c --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing.pytest_wrapper import run_on_environment_flag + + +class ReshapeModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other) + reshape_node = conv_node.view(2, -1) + return reshape_node + + +def test_reshape_handler(): + model = ReshapeModel() + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) + # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) + # return view + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(4, 4, 64, 64).to('meta'), + "other": torch.rand(4, 16, 3, 3).to('meta'), + }) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + conv_mod_node = list(graph.nodes)[2] + reshape_node = list(graph.nodes)[3] + reshape_strategies_vector = StrategiesVector(reshape_node) + conv_strategies_vector = StrategiesVector(conv_mod_node) + + # build handler + conv_handler = ConvFunctionHandler(node=conv_mod_node, + device_mesh=device_mesh, + strategies_vector=conv_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) + reshape_handler = ReshapeHandler(node=reshape_node, + device_mesh=device_mesh, + strategies_vector=reshape_strategies_vector) + + reshape_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = reshape_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "conv2d" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62]) + + assert mapping['output'].name == "view" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([2, 30752]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(reshape_strategies_vector) == len(conv_strategies_vector) + + +if __name__ == '__main__': + test_reshape_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e8e32778be6ba1784467424b18467705657eea --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py @@ -0,0 +1,186 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.softmax_handler import SoftmaxHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class LinearSplitModel(nn.Module): + + def __init__(self, softmax_dim): + super().__init__() + self.softmax_dim = softmax_dim + + def forward(self, input, other): + linear_node = F.linear(input, other, bias=None) + softmax_node = F.softmax(linear_node, self.softmax_dim) + return softmax_node + + +def check_split_handler(rank, softmax_dim, model_cls, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = model_cls(softmax_dim=softmax_dim).cuda() + + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer() + + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %softmax : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) + # return split + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), + }) + + gm = ColoGraphModule(model, graph) + + previous_mod_node = list(graph.nodes)[2] + split_node = list(graph.nodes)[3] + split_strategies_vector = StrategiesVector(split_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + softmax_handler = SoftmaxHandler(node=split_node, + device_mesh=device_mesh, + strategies_vector=split_strategies_vector) + + softmax_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = softmax_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + assert mapping['softmax_dim'].name == "softmax_dim" + assert mapping['softmax_dim'].data == softmax_dim + assert mapping['softmax_dim'].type == OperationDataType.ARG + + assert mapping['output'].name == "softmax" + assert mapping['output'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['output'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(split_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in split_strategies_vector] + + if softmax_dim == 0: + assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + + if softmax_dim == 1: + assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('softmax_dim', [0, 1, 2, 3]) +@parameterize('model_cls', [LinearSplitModel]) +def test_split_handler(softmax_dim, model_cls): + world_size = 4 + run_func = partial(check_split_handler, + softmax_dim=softmax_dim, + model_cls=model_cls, + world_size=world_size, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_split_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..9e8e905c54a28840bfcce708ca3bfa6ac0473512 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py @@ -0,0 +1,270 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.experimental import SplitHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class ConvSplitModel(nn.Module): + + def __init__(self, split_size, split_dim): + super().__init__() + self.split_size = split_size + self.split_dim = split_dim + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other, bias=None) + split_node = conv_node.split(self.split_size, dim=self.split_dim) + return split_node + + +class LinearSplitModel(nn.Module): + + def __init__(self, split_size, split_dim): + super().__init__() + self.split_size = split_size + self.split_dim = split_dim + + def forward(self, input, other): + linear_node = nn.functional.linear(input, other, bias=None) + split_node = linear_node.split(self.split_size, dim=self.split_dim) + return split_node + + +def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = model_cls(split_size=split_size, split_dim=split_dim).cuda() + + if model_cls.__name__ == 'ConvSplitModel': + input = torch.rand(8, 8, 66, 66).to('cuda') + other = torch.rand(16, 8, 3, 3).to('cuda') + # index of conv node in computation graph + node_index = 2 + # total number of conv strategies + strategy_number = 16 + if model_cls.__name__ == 'LinearSplitModel': + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer() + if model_cls.__name__ == 'ConvSplitModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) + # %split : [#users=1] = call_method[target=split](args = (%conv2d,), kwargs = {}) + # return split + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 8, 66, 66).to('meta'), + "other": torch.rand(16, 8, 3, 3).to('meta'), + }) + + if model_cls.__name__ == 'LinearSplitModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %split : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) + # return split + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), + }) + + gm = ColoGraphModule(model, graph) + + previous_mod_node = list(graph.nodes)[2] + split_node = list(graph.nodes)[3] + split_strategies_vector = StrategiesVector(split_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + if model_cls.__name__ == 'ConvSplitModel': + + conv_handler = ConvFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + if model_cls.__name__ == 'LinearSplitModel': + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + split_handler = SplitHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector) + + split_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = split_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + if model_cls.__name__ == 'ConvSplitModel': + assert mapping['input'].name == "conv2d" + else: + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + assert mapping['output'].name == "split" + split_items = torch.empty([8, 16, 64, 64]).split(split_size, split_dim) + assert mapping['output'].logical_shape == tuple([item.shape for item in split_items]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(split_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in split_strategies_vector] + for name in strategy_name_list: + print(name) + if model_cls.__name__ == 'ConvSplitModel': + + if split_dim == 0: + assert '[R, S1, R, R]_0' in strategy_name_list + assert '[R, S0, R, R]_1' in strategy_name_list + assert '[R, R, R, R]_2' in strategy_name_list + assert '[R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, R]_4' in strategy_name_list + assert '[R, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R]_6' in strategy_name_list + assert '[R, S0, R, R]_7' in strategy_name_list + assert '[R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R]_10' in strategy_name_list + assert '[R, S1, R, R]_11' in strategy_name_list + assert '[R, R, R, R]_12' in strategy_name_list + assert '[R, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R]_15' in strategy_name_list + + if split_dim == 1: + assert '[S0, R, R, R]_0' in strategy_name_list + assert '[S1, R, R, R]_1' in strategy_name_list + assert '[S0, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R]_5' in strategy_name_list + assert '[R, R, R, R]_6' in strategy_name_list + assert '[R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R]_9' in strategy_name_list + assert '[R, R, R, R]_10' in strategy_name_list + assert '[R, R, R, R]_11' in strategy_name_list + assert '[R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R]_15' in strategy_name_list + + if model_cls.__name__ == 'LinearSplitModel': + + if split_dim == 0: + assert '[R, R, R, S1]_0' in strategy_name_list + assert '[R, S0, R, S1]_1' in strategy_name_list + assert '[R, R, S0, S1]_2' in strategy_name_list + assert '[R, R, R, S0]_3' in strategy_name_list + assert '[R, S1, R, S0]_4' in strategy_name_list + assert '[R, R, S1, S0]_5' in strategy_name_list + assert '[R, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R]_7' in strategy_name_list + assert '[R, R, S0, R]_8' in strategy_name_list + assert '[R, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R]_10' in strategy_name_list + assert '[R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1]_17' in strategy_name_list + assert '[R, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R]_19' in strategy_name_list + assert '[R, R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01]_22' in strategy_name_list + + if split_dim == 1: + assert '[S0, R, R, S1]_0' in strategy_name_list + assert '[R, R, R, S1]_1' in strategy_name_list + assert '[R, R, S0, S1]_2' in strategy_name_list + assert '[S1, R, R, S0]_3' in strategy_name_list + assert '[R, R, R, S0]_4' in strategy_name_list + assert '[R, R, S1, S0]_5' in strategy_name_list + assert '[S0, R, R, R]_6' in strategy_name_list + assert '[R, R, R, R]_7' in strategy_name_list + assert '[R, R, S0, R]_8' in strategy_name_list + assert '[S1, R, R, R]_9' in strategy_name_list + assert '[R, R, R, R]_10' in strategy_name_list + assert '[R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1]_17' in strategy_name_list + assert '[S01, R, R, R]_18' in strategy_name_list + assert '[R, R, R, R]_19' in strategy_name_list + assert '[R, R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01]_22' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('split_size', [2]) +@parameterize('split_dim', [0, 1, 2]) +@parameterize('model_cls', [ConvSplitModel, LinearSplitModel]) +def test_split_handler(split_size, split_dim, model_cls): + world_size = 4 + run_func = partial(check_split_handler, + split_size=split_size, + split_dim=split_dim, + model_cls=model_cls, + world_size=world_size, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_split_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..5fda4de1a101301fac99760f427785d6b6ba61ad --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py @@ -0,0 +1,235 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.sum_handler import SumHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class LinearSumModel(nn.Module): + + def __init__(self, sum_dims, keepdim): + super().__init__() + self.sum_dims = sum_dims + self.keepdim = keepdim + + def forward(self, input, other): + linear_node = nn.functional.linear(input, other, bias=None) + if self.sum_dims is not None: + sum_node = torch.sum(linear_node, self.sum_dims, keepdim=self.keepdim) + else: + sum_node = torch.sum(linear_node, keepdim=self.keepdim) + return sum_node + + +def check_sum_handler(rank, sum_dims, keepdim, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 24 + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + + tracer = ColoTracer() + + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%linear,), kwargs = {}) + # return sum_1 + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), + }) + gm = ColoGraphModule(model, graph) + + previous_mod_node = list(graph.nodes)[2] + sum_node = list(graph.nodes)[3] + sum_strategies_vector = StrategiesVector(sum_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + sum_handler = SumHandler(node=sum_node, device_mesh=device_mesh, strategies_vector=sum_strategies_vector) + + sum_handler.register_strategy(compute_resharding_cost=False) + + # sum handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(sum_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in sum_strategies_vector] + + # check operation data mapping + mapping = sum_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + assert mapping['output'].name == "sum_1" + sum_node_shape = torch.empty([8, 16, 64, 64]).sum(sum_dims, keepdim=keepdim).shape + assert mapping['output'].logical_shape == sum_node_shape + assert mapping['output'].type == OperationDataType.OUTPUT + + # check strategy name + if sum_dims == (0, 2) and keepdim == False: + assert '[R, R, R, S1] -> [R, S1]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [S0, S1]_1' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1]_2' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [S1, S0]_4' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0]_5' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [S0, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [S1, R]_10' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1]_17' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [S01, R]_19' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, S01]_22' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_23' in strategy_name_list + + if sum_dims == (0, 2) and keepdim == True: + assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_2' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_5' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list + + if sum_dims == 1 and keepdim == False: + assert '[S0, R, R, S1] -> [S0, R, S1]_0' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, S0, S1]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, S0]_3' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, S1, S0]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R]_6' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, S0, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, S1]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R]_18' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, S01]_22' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_23' in strategy_name_list + + if sum_dims == 1 and keepdim == True: + assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('sum_dims', [(0, 2), 1]) +@parameterize('keepdim', [False, True]) +def test_sum_handler(sum_dims, keepdim): + world_size = 4 + run_func = partial(check_sum_handler, sum_dims=sum_dims, keepdim=keepdim, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_sum_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..0c67abc7da61a00511f54d78a487dfcbde1a5fe5 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer + + +class TensorConstructorModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + arange_node = torch.arange(x.size()[0]) + x = x + arange_node + return x + + +def test_where_handler(): + model = TensorConstructorModel() + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=2] = placeholder[target=x] + # %size : [#users=1] = call_method[target=size](args = (%x,), kwargs = {}) + # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%size, 0), kwargs = {}) + # %arange : [#users=1] = call_function[target=torch.arange](args = (%getitem,), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%x, %arange), kwargs = {}) + # return add + graph = tracer.trace(model, meta_args={ + "x": torch.rand(10).to('meta'), + }) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + arange_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(arange_node) + + # build handler + handler = TensorConstructorHandler(node=arange_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['output'].name == "arange" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([10]) + assert mapping['output'].type == OperationDataType.OUTPUT + + handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + assert 'Replica Tensor Constructor' in strategy_name_list + + +if __name__ == '__main__': + test_where_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..e4d12cd12ffda9a6fb4637827d88257004bca9aa --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx.tracer.meta_patch.patched_module import linear +from colossalai.testing.pytest_wrapper import run_on_environment_flag + + +class ReLuModel(nn.Module): + + def __init__(self): + super().__init__() + self.act = torch.nn.ReLU() + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other) + relu_node = self.act(conv_node) + return relu_node + + +def test_elementwise_handler(): + model = ReLuModel() + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) + # %act : [#users=1] = call_module[target=act](args = (%conv2d,), kwargs = {}) + # return act + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(4, 4, 64, 64).to('meta'), + "other": torch.rand(4, 16, 3, 3).to('meta'), + }) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + conv_mod_node = list(graph.nodes)[2] + relu_mod_node = list(graph.nodes)[3] + relu_strategies_vector = StrategiesVector(relu_mod_node) + conv_strategies_vector = StrategiesVector(conv_mod_node) + + # build handler + conv_handler = ConvFunctionHandler(node=conv_mod_node, + device_mesh=device_mesh, + strategies_vector=conv_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) + relu_handler = UnaryElementwiseHandler(node=relu_mod_node, + device_mesh=device_mesh, + strategies_vector=relu_strategies_vector) + + relu_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = relu_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "conv2d" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62]) + + assert mapping['output'].name == "act" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 4, 62, 62]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(relu_strategies_vector) == len(conv_strategies_vector) + + +if __name__ == '__main__': + test_elementwise_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..08a702789f9f051d1439ae012049b5003a9fadf9 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py @@ -0,0 +1,265 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.experimental import ViewHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class ConvViewModel(nn.Module): + + def __init__(self, tgt_shape): + super().__init__() + self.tgt_shape = tgt_shape + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other, bias=None) + reshape_node = conv_node.view(*self.tgt_shape) + return reshape_node + + +class LinearViewModel(nn.Module): + + def __init__(self, tgt_shape): + super().__init__() + self.tgt_shape = tgt_shape + + def forward(self, input, other): + linear_node = nn.functional.linear(input, other, bias=None) + reshape_node = linear_node.view(*self.tgt_shape) + return reshape_node + + +def check_view_handler(rank, tgt_shape, model_cls, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = model_cls(tgt_shape).cuda() + + if model_cls.__name__ == 'ConvViewModel': + input = torch.rand(8, 8, 66, 66).to('cuda') + other = torch.rand(16, 8, 3, 3).to('cuda') + # index of conv node in computation graph + node_index = 2 + # total number of conv strategies + strategy_number = 16 + if model_cls.__name__ == 'LinearViewModel': + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer() + if model_cls.__name__ == 'ConvViewModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) + # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) + # return view + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 8, 66, 66).to('meta'), + "other": torch.rand(16, 8, 3, 3).to('meta'), + }) + + if model_cls.__name__ == 'LinearViewModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %view : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) + # return view + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), + }) + + gm = ColoGraphModule(model, graph) + + previous_mod_node = list(graph.nodes)[2] + view_node = list(graph.nodes)[3] + view_strategies_vector = StrategiesVector(view_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + if model_cls.__name__ == 'ConvViewModel': + + conv_handler = ConvFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + if model_cls.__name__ == 'LinearViewModel': + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + view_handler = ViewHandler(node=view_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector) + + view_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = view_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + if model_cls.__name__ == 'ConvViewModel': + assert mapping['input'].name == "conv2d" + else: + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + assert mapping['output'].name == "view" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size(tgt_shape) + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(view_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in view_strategies_vector] + + if model_cls.__name__ == 'ConvViewModel': + + if tgt_shape == (32, 4, 64, 16, 4): + assert '[S0, S1, R, R] -> FULLY REPLICATED_0' in strategy_name_list + assert '[S1, S0, R, R] -> FULLY REPLICATED_1' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> FULLY REPLICATED_6' in strategy_name_list + assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> FULLY REPLICATED_10' in strategy_name_list + assert '[R, S1, R, R] -> FULLY REPLICATED_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> FULLY REPLICATED_15' in strategy_name_list + + if tgt_shape == (8, 4, 4, 64, 16, 4): + assert '[S0, S1, R, R] -> [S0, S1, R, R, R, R]_0' in strategy_name_list + assert '[S1, S0, R, R] -> [S1, S0, R, R, R, R]_1' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_10' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_15' in strategy_name_list + + if model_cls.__name__ == 'LinearViewModel': + + if tgt_shape == (32, 4, 64, 16, 4): + assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_0' in strategy_name_list + assert '[R, S0, R, S1] -> FULLY REPLICATED_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_3' in strategy_name_list + assert '[R, S1, R, S0] -> FULLY REPLICATED_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> FULLY REPLICATED_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1, R]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0, R]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1, R]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> FULLY REPLICATED_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01, R]_22' in strategy_name_list + + if tgt_shape == (8, 4, 4, 64, 16, 4): + assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_22' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)]) +@parameterize('model_cls', [ConvViewModel, LinearViewModel]) +def test_view_handler(tgt_shape, model_cls): + world_size = 4 + run_func = partial(check_view_handler, + tgt_shape=tgt_shape, + model_cls=model_cls, + world_size=world_size, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_view_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..9838e2eb01c6539553a6e42dc1c762ce9064b5c5 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import \ + WhereHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx.tracer.meta_patch.patched_module import linear + + +class ConvModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, condition, x, y): + output = torch.where(condition, x, y) + return output + + +def test_where_handler(): + model = ConvModel() + tracer = ColoTracer() + # graph(): + # %condition : torch.Tensor [#users=1] = placeholder[target=condition] + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %y : torch.Tensor [#users=1] = placeholder[target=y] + # %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {}) + # return where + graph = tracer.trace(model, + meta_args={ + "condition": torch.rand(4, 4, 64, 64).to('meta'), + "x": torch.rand(4, 1, 64, 64).to('meta'), + "y": torch.rand(1, 4, 64, 64).to('meta') + }) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + where_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(where_node) + + # build handler + handler = WhereHandler(node=where_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping, _ = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['condition'].name == "condition" + assert mapping['condition'].data.is_meta + assert mapping['condition'].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping['condition'].type == OperationDataType.ARG + assert mapping['condition'].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping['x'].name == "x" + assert mapping['x'].data.is_meta + assert mapping['x'].data.shape == torch.Size([4, 1, 64, 64]) + assert mapping['x'].type == OperationDataType.ARG + assert mapping['x'].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping['y'].name == "y" + assert mapping['y'].data.is_meta + assert mapping['y'].data.shape == torch.Size([1, 4, 64, 64]) + assert mapping['y'].type == OperationDataType.ARG + assert mapping['y'].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping['output'].name == "where" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + # 4*3 + 4*3/2*2 + 1 + assert len(strategy_name_list) == 25 + + +if __name__ == '__main__': + test_where_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9d9a625a4801611bb19a79698f7fa7bba4392620 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py @@ -0,0 +1,178 @@ +import copy +from typing import Dict, List + +import torch +from torch.fx import GraphModule + +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph +from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser +from colossalai.auto_parallel.tensor_shard.solver.solver import Solver +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.shape_consistency import to_global +from colossalai.testing.comparison import assert_close + + +def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tensor], + input_kwargs: Dict[str, torch.Tensor], grad_dict: Dict[any, torch.Tensor]): + + model_to_compare = copy.deepcopy(model) + args_to_compare = [] + kwargs_to_compare = {} + for arg_index, input_tensor in enumerate(input_args): + + def wrapper(param, index): + + def hook_fn(grad): + grad_dict[index] = grad + + param.register_hook(hook_fn) + + arg_to_compare = copy.deepcopy(input_tensor) + + # only Tensors of floating point and complex dtype can require gradients + if arg_to_compare.dtype != torch.int64: + arg_to_compare.requires_grad = True + wrapper(arg_to_compare, arg_index) + + args_to_compare.append(arg_to_compare) + + for name, input_kwarg in input_kwargs.items(): + + def wrapper(param, name): + + def hook_fn(grad): + grad_dict[name] = grad + + param.register_hook(hook_fn) + + kwarg_to_compare = copy.deepcopy(input_kwarg) + + # only Tensors of floating point and complex dtype can require gradients + if kwarg_to_compare.dtype != torch.int64: + kwarg_to_compare.requires_grad = True + wrapper(kwarg_to_compare, name) + + kwargs_to_compare[name] = kwarg_to_compare + + return model_to_compare, args_to_compare, kwargs_to_compare + + +def numerical_test_for_node_strategy(model: torch.nn.Module, + device_mesh: DeviceMesh, + node_index: int, + strategy_number: int, + input_args: List[torch.Tensor], + meta_arg_names: List[str], + input_kwargs: Dict[str, torch.Tensor] = {}, + node_type: str = 'normal'): + for strategy_index in range(strategy_number): + print(f'#strategy_index: {strategy_index}') + # We need to copy the model to avoid do backward more than once in same graph + grad_to_compare_dict = {} + grad_to_shard_dict = {} + model_to_compare, args_to_compare, kwargs_to_compare = _build_model_to_compare( + model, input_args, input_kwargs, grad_to_compare_dict) + model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs, + grad_to_shard_dict) + + tracer = ColoTracer() + input_sample = {} + for input_arg, meta_arg_name in zip(input_args, meta_arg_names): + input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta') + for meta_kwarg_name, input_kwarg in input_kwargs.items(): + input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta') + graph = tracer.trace(root=model_to_shard, meta_args=input_sample) + gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + target_node = list(graph.nodes)[node_index] + if node_type == 'normal': + solution_len = len(strategies_constructor.leaf_strategies) + solution = [0] * solution_len + solution[node_index] = strategy_index + elif node_type == 'following': + solution_len = len(strategies_constructor.leaf_strategies) + solution = [0] * solution_len + solution[node_index] = strategy_index + solution[node_index + 1] = strategy_index + else: + node_vector = strategies_constructor.leaf_strategies[node_index] + strategy_to_keep = node_vector[strategy_index] + node_vector = [strategy_to_keep] + # solution construction + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + graph_analyser = GraphAnalyser(gm) + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False) + ret = solver.call_solver_serialized_args() + solution = list(ret[0]) + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( + gm, solution, device_mesh) + gm = runtime_apply_pass(gm) + gm.recompile() + + # forward result compare + output = gm(*args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard) + output_to_compare = model_to_compare(*args_to_compare, **kwargs_to_compare) + assert_close_helper(output, output_to_compare, strategy_index=strategy_index, type='forward output') + + # backward result compare + if isinstance(output, (tuple, list)): + loss = output[0].sum() + loss_to_compare = output_to_compare[0].sum() + else: + loss = output.sum() + loss_to_compare = output_to_compare.sum() + + loss_to_compare.backward() + loss.backward() + for key in grad_to_shard_dict.keys(): + grad_to_shard = grad_to_shard_dict[key] + grad_to_compare = grad_to_compare_dict[key] + assert_close_helper(grad_to_shard, grad_to_compare, strategy_index=strategy_index, type='input grad') + # extract the strategy used in this iter + strategy_in_use = target_node.strategies_vector[strategy_index] + param_to_shard_dict = dict(gm.named_parameters()) + param_to_compare_dict = dict(model_to_compare.named_parameters()) + for name in param_to_shard_dict.keys(): + param_name = name.split('.')[-1] + if node_type == 'normal': + param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name) + else: + if 'weight' in name: + param_sharding_spec = list(graph.nodes)[4].sharding_spec + elif 'bias' in name: + param_sharding_spec = list(graph.nodes)[5].sharding_spec + + grad_sharded = param_to_shard_dict[name].grad + grad_to_compare = param_to_compare_dict[name].grad + global_grad = to_global(grad_sharded, param_sharding_spec) + assert_close_helper(global_grad, grad_to_compare, strategy_index=strategy_index, type='param grad') + + +def assert_close_helper(first: torch.Tensor, + second: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2, + strategy_index: int = -1, + type: str = 'not defined'): + """ + This method is used to check whether the average difference between two tensors is as close as expected. + """ + try: + if isinstance(first, (tuple, list)): + for first_element, second_element in zip(first, second): + assert_close(first_element, second_element, rtol=rtol, atol=atol) + else: + assert_close(first, second, rtol=rtol, atol=atol) + except: + print(f'strategy index {strategy_index} encounter assert_close error on {type}') diff --git a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py new file mode 100644 index 0000000000000000000000000000000000000000..611402fe8394fdc5c15f78e0155ad51c989ef3a2 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py @@ -0,0 +1,128 @@ +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType +from colossalai.auto_parallel.tensor_shard.solver import ( + CostGraph, + GraphAnalyser, + Solver, + SolverOptions, + StrategiesConstructor, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer + + +def _param_resharding_cost_assertion(node): + for strategy in node.strategies_vector: + for prev_node, resharding_cost in strategy.resharding_costs.items(): + if strategy.get_op_data_by_name(str(prev_node)).type == OperationDataType.PARAM: + for cost in resharding_cost: + assert cost.fwd == 0 + assert cost.bwd == 0 + assert cost.total == 0 + + +class LinearModel(torch.nn.Module): + + def __init__(self, in_features, out_features): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x): + x = self.linear(x) + x = x * 2 + + return x + + +class ConvModel(torch.nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, bias=True): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias) + + def forward(self, x): + x = self.conv(x) + x = x * 2 + + return x + + +def test_linear_module(): + model = LinearModel(4, 8) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %linear_weight : [#users=1] = get_attr[target=linear.weight] + # %linear_bias : [#users=1] = get_attr[target=linear.bias] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) + # return mul + graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')}) + # def forward(self, x : torch.Tensor): + # linear_weight = self.linear.weight + # linear_bias = self.linear.bias + # linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None + # add = linear + linear_bias; linear = linear_bias = None + # mul = add * 2; add = None + # return mul + gm = ColoGraphModule(model, graph) + gm.recompile() + node_list = list(graph.nodes) + + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + linear_node = node_list[3] + _param_resharding_cost_assertion(linear_node) + + +def test_conv_module(): + model = ConvModel(3, 6, 2) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # %conv_bias : [#users=1] = get_attr[target=conv.bias] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {}) + # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) + # return mul + graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')}) + # def forward(self, x : torch.Tensor): + # conv_weight = self.conv.weight + # conv_bias = self.conv.bias + # conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None + # view = conv_bias.view([1, -1, 1, 1]); conv_bias = None + # add = conv2d + view; conv2d = view = None + # mul = add * 2; add = None + # return mul + gm = ColoGraphModule(model, graph) + + gm.recompile() + node_list = list(graph.nodes) + conv_node = node_list[3] + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + _param_resharding_cost_assertion(conv_node) + + +if __name__ == '__main__': + test_linear_module() + test_conv_module() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py b/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..814edd27948cc3db6a6868d5ccea8a3215d166db --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py @@ -0,0 +1,270 @@ +import copy +from copy import deepcopy +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from torch.fx import GraphModule +from torchvision.models import resnet34, resnet50 + +from colossalai import device +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.constants import * +from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph +from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser +from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.solver.solver import Solver +from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port + +seed = 128 +cudnn_benchmark = False +cudnn_deterministic = True + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample=None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer=None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = self.relu(out) + + return out + + +def check_apply_bottleneck(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + input = torch.rand(4, 4, 4, 4).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + tracer = ColoTracer() + model = Bottleneck(4, 4, 1, norm_layer=torch.nn.modules.batchnorm.BatchNorm2d).cuda() + test_model = copy.deepcopy(model) + test_input = copy.deepcopy(input) + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {}) + # %bn1 : [#users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {}) + # %relu : [#users=1] = call_module[target=relu](args = (%bn1,), kwargs = {}) + # %conv2 : [#users=1] = call_module[target=conv2](args = (%relu,), kwargs = {}) + # %bn2 : [#users=1] = call_module[target=bn2](args = (%conv2,), kwargs = {}) + # %relu_1 : [#users=1] = call_module[target=relu](args = (%bn2,), kwargs = {}) + # %conv3 : [#users=1] = call_module[target=conv3](args = (%relu_1,), kwargs = {}) + # %bn3 : [#users=1] = call_module[target=bn3](args = (%conv3,), kwargs = {}) + # %relu_2 : [#users=1] = call_module[target=relu](args = (%bn3,), kwargs = {}) + # return relu_2 + input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')} + + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + graph_analyser = GraphAnalyser(gm) + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + ret = solver.call_solver_serialized_args() + solution = list(ret[0]) + print(solution) + for index, node in enumerate(graph.nodes): + print(node.name, node.strategies_vector[solution[index]].name) + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh) + gm = runtime_apply_pass(gm) + gm.recompile() + nodes = [node for node in gm.graph.nodes] + # TODO: wrap the gm to avoid the influence of the user training code + cuda_rng_state = torch.cuda.get_rng_state() + origin_output = test_model(test_input) + torch.cuda.set_rng_state(cuda_rng_state) + output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict) + + assert output.shape == origin_output.shape + assert_close(output, origin_output, rtol=1e-03, atol=1e-05) + print("*******************backward starting*******************") + cuda_rng_state = torch.cuda.get_rng_state() + output.sum().backward() + torch.cuda.set_rng_state(cuda_rng_state) + origin_output.sum().backward() + if rank == 0: + print( + f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum()}" + ) + print( + f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum()}" + ) + print( + f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 0, 2)).abs().sum()}" + ) + print( + f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 0, 2)).abs().sum()}" + ) + print( + f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 0, 1)).abs().sum()}" + ) + print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}") + + assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum()) + assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 0, 2).sum()) + assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum()) + + if rank == 1: + print( + f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum()}" + ) + print( + f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum()}" + ) + print( + f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 2, 2)).abs().sum()}" + ) + print( + f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 2, 2)).abs().sum()}" + ) + print( + f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 1, 1)).abs().sum()}" + ) + print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}") + + assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum()) + assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 2, 2).sum()) + assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum()) + + if rank == 2: + print( + f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum()}" + ) + print( + f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum()}" + ) + print( + f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 0, 2)).abs().sum()}" + ) + print( + f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 0, 2)).abs().sum()}" + ) + print( + f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 2, 1)).abs().sum()}" + ) + print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}") + + assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum()) + assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 0, 2).sum()) + assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum()) + + if rank == 3: + print( + f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum()}" + ) + print( + f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum()}" + ) + print( + f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 2, 2)).abs().sum()}" + ) + print( + f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 2, 2)).abs().sum()}" + ) + print( + f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 3, 1)).abs().sum()}" + ) + print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}") + + assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum()) + assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 2, 2).sum()) + assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum()) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_apply(): + world_size = 4 + run_func = partial(check_apply_bottleneck, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_apply() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py b/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..66cd3f3f770708b228bf4012c350c4d3fe76655b --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py @@ -0,0 +1,107 @@ +import copy +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from torch.fx import GraphModule + +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.solver import ( + CostGraph, + GraphAnalyser, + Solver, + SolverOptions, + StrategiesConstructor, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port + + +class ConvModel(nn.Module): + + def __init__(self, c_in, c_out): + super().__init__() + self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False) + + def forward(self, x): + x = self.conv(x) + x = torch.flatten(x) + return x + + +def check_apply(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + input = torch.rand(4, 4, 4, 4).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + tracer = ColoTracer() + model = ConvModel(4, 4).cuda() + test_model = copy.deepcopy(model) + test_input = copy.deepcopy(input) + + input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')} + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {}) + # return conv + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + graph_analyser = GraphAnalyser(gm) + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + ret = solver.call_solver_serialized_args() + solution = list(ret[0]) + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh) + gm = runtime_apply_pass(gm) + gm.recompile() + nodes = [node for node in gm.graph.nodes] + # TODO: wrap the gm to avoid the influence of the user training code + output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict) + origin_output = test_model(test_input) + assert output.equal(origin_output) + origin_loss = origin_output.sum() + loss = output.sum() + + origin_loss.backward() + loss.backward() + + grad_0 = test_model.conv.weight.grad.narrow(0, 0, 2) + grad_1 = test_model.conv.weight.grad.narrow(0, 2, 2) + + if rank in (0, 1): + assert_close(gm.conv.weight.grad.data, grad_0.data) + elif rank in (2, 3): + assert_close(gm.conv.weight.grad.data, grad_1.data) + + +# skip this test due to pulp not installed in CI environment +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_apply(): + world_size = 4 + run_func = partial(check_apply, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_apply() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py new file mode 100644 index 0000000000000000000000000000000000000000..82accebdb032ddec67289fb03b8175551e06b710 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py @@ -0,0 +1,330 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import transformers +from torch.fx import GraphModule +from transformers.models.gpt2.modeling_gpt2 import ( + GPT2MLP, + BaseModelOutputWithPastAndCrossAttentions, + GPT2PreTrainedModel, +) +from transformers.pytorch_utils import Conv1D + +from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP +from colossalai.auto_parallel.tensor_shard.solver import ( + CostGraph, + GraphAnalyser, + Solver, + SolverOptions, + StrategiesConstructor, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.testing import parameterize +from colossalai.testing.pytest_wrapper import run_on_environment_flag + +BATCH_SIZE = 1 +SEQ_LENGTH = 32 +HIDDEN_DIM = 768 + + +# The reason Why we don't import GPT2Attention from transformers directly is that: +# 1. The tracer will not work correctly when we feed meta_args and concrete_args at same time, +# so we have to build the customized GPT2Attention class and remove the conditional branch manually. +# 2. The order of split and view op has been changed in the customized GPT2Attention class, the new +# order is same as megatron-lm gpt model. +class GPT2Attention(nn.Module): + + def __init__(self, config, layer_idx=None): + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), + dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + ) + self.register_buffer("masked_bias", torch.tensor(-1e4)) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + self.scale_attn_weights = config.scale_attn_weights + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.pruned_heads = set() + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / (value.size(-1)**0.5) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool) + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + # query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + qkv = self.c_attn(hidden_states) + + # query = self._split_heads(query, self.num_heads, self.head_dim) + # key = self._split_heads(key, self.num_heads, self.head_dim) + # value = self._split_heads(value, self.num_heads, self.head_dim) + query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3) + present = (key, value) + + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +class GPT2Block(nn.Module): + + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + # %transformer_h_0_ln_1 + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2Model(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = ["attn.masked_bias"] + + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + + device = input_ids.device + + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + # add_2 + hidden_states = inputs_embeds + position_embeds + + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + # transformer_drop + hidden_states = self.drop(hidden_states) + # comment to run pipeline + # add_3 + output_shape = input_shape + (hidden_states.size(-1),) + + presents = None + all_self_attentions = None + all_cross_attentions = None + all_hidden_states = None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i]) + hidden_states = outputs[0] + + hidden_states = self.ln_f(hidden_states) + # comment to run pipeline + hidden_states = hidden_states.view(output_shape) + + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) +def test_self_attention_block(model_cls): + config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM) + if model_cls == GPT2MLP: + model = model_cls(intermediate_size=4 * config.hidden_size, config=config) + else: + model = model_cls(config=config) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer() + if model_cls == GPT2MLP: + input_sample = { + 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), + } + elif model_cls in (GPT2Attention, GPT2Block): + input_sample = { + 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), + 'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'), + } + else: + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + input_sample = {k: v.to('meta') for k, v in kwargs.items()} + + graph = tracer.trace(root=model, meta_args=input_sample) + + gm = GraphModule(model, graph, model.__class__.__name__) + print(gm.graph) + gm.recompile() + graph_analyser = GraphAnalyser(gm) + liveness_list = graph_analyser.liveness_analysis() + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1) + ret = solver.call_solver_serialized_args() + strategies_list = solver.last_s_val + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + + computation_cost = 0 + communication_cost = 0 + memory_cost = 0 + for index, node in enumerate(nodes): + print(node.name, node.strategies_vector[strategies_list[index]].name) + computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total + communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total + node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total + if isinstance(node_memory_cost, tuple): + node_memory_cost = node_memory_cost[0] + memory_cost += node_memory_cost.activation + node_memory_cost.parameter + + print(f'computation cost is {computation_cost}') + print(f'communication cost is {communication_cost}') + print(f'memory cost is {memory_cost}') + + +if __name__ == '__main__': + test_self_attention_block() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a5ae7ac1c01d684d7adc9a810d7919784b6533 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py @@ -0,0 +1,98 @@ +import torch +from torch.fx import GraphModule +from torchvision.models import resnet50 + +from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP +from colossalai.auto_parallel.tensor_shard.solver import ( + CostGraph, + GraphAnalyser, + Solver, + SolverOptions, + StrategiesConstructor, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.testing.pytest_wrapper import run_on_environment_flag + + +@run_on_environment_flag(name='AUTO_PARALLEL') +def test_cost_graph(): + physical_mesh_id = torch.arange(0, 8) + mesh_shape = (2, 4) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer() + model = resnet50(num_classes=100000) + input_sample = {'x': torch.rand(128, 3, 224, 224).to('meta')} + + graph = tracer.trace(root=model, meta_args=input_sample) + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {}) + # %bn1 : [#users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {}) + # %relu : [#users=1] = call_module[target=relu](args = (%bn1,), kwargs = {}) + # %maxpool : [#users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {}) + # %layer1_0_conv1 : [#users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {}) + # %layer1_0_bn1 : [#users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {}) + # %layer1_0_relu : [#users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {}) + # %layer1_0_conv2 : [#users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {}) + # %layer1_0_bn2 : [#users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {}) + # %layer1_0_relu_1 : [#users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {}) + # %layer1_1_conv1 : [#users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {}) + # %layer1_1_bn1 : [#users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {}) + # %layer1_1_relu : [#users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {}) + # %layer1_1_conv2 : [#users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {}) + # %layer1_1_bn2 : [#users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {}) + # %add_1 : [#users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {}) + # ... + # %avgpool : [#users=1] = call_module[target=avgpool](args = (%layer4_2_relu_1,), kwargs = {}) + # %flatten : [#users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {}) + # %fc : [#users=1] = call_module[target=fc](args = (%flatten,), kwargs = {}) + # return fc + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + graph_analyser = GraphAnalyser(gm) + liveness_list = graph_analyser.liveness_analysis() + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + + ret = solver.call_solver_serialized_args() + print(ret[0]) + print(solver.last_s_val) + strategies_list = solver.last_s_val + + computation_cost = 0 + communication_cost = 0 + communication_cost_bn = 0 + memory_cost = 0 + for index, node in enumerate(graph.nodes): + if node.op == 'call_module': + submod = node.graph.owning_module.get_submodule(node.target) + if type(submod) in BATCHNORM_MODULE_OP: + communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost.total + print(node.name, node.strategies_vector[strategies_list[index]].name) + computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total + communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total + node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total + if isinstance(node_memory_cost, tuple): + node_memory_cost = node_memory_cost[0] + memory_cost += node_memory_cost.activation + node_memory_cost.parameter + + print(f'computation cost is {computation_cost}') + print(f'communication cost is {communication_cost}') + print(f'memory cost is {memory_cost}') + print(f'bn communication cost is {communication_cost_bn}') + + +if __name__ == '__main__': + test_cost_graph() diff --git a/tests/test_comm/test_boardcast_send_recv_v2.py b/tests/test_comm/test_boardcast_send_recv_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..1520d605404312a1fc1b5d4d3978c73d052e2401 --- /dev/null +++ b/tests/test_comm/test_boardcast_send_recv_v2.py @@ -0,0 +1,54 @@ +from functools import partial +from typing import List + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from colossalai.communication.p2p_v2 import _send_object, _recv_object, init_process_group +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.utils import free_port, get_current_device +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.logging import disable_existing_loggers + +disable_existing_loggers() +world_size = 4 +CONFIG = dict(parallel=dict(pipeline=world_size)) +torch.manual_seed(123) + + +def check_layer(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl', verbose=False) + rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + if rank == 0: + obj = [torch.randn(3,)] + _send_object(obj, 1) + + if rank == 1: + _recv_object(0) + + if rank == 2: + _recv_object(3) + + if rank == 3: + obj = [torch.randn(3,)] + _send_object(obj, 2) + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_object_list_p2p(): + disable_existing_loggers() + run_func = partial(check_layer, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_object_list_p2p() diff --git a/tests/test_comm/test_comm.py b/tests/test_comm/test_comm.py new file mode 100644 index 0000000000000000000000000000000000000000..07cb67730d24d4169705951ccbc5577760e7cda4 --- /dev/null +++ b/tests/test_comm/test_comm.py @@ -0,0 +1,75 @@ +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from colossalai.communication import all_gather, all_reduce, reduce_scatter +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.utils import free_port, get_current_device +from colossalai.testing import rerun_if_address_is_in_use + +CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1))) + +SIZE = 8 + + +def check_all_gather(): + tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) + tensor = tensor.to(get_current_device()) + print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True) + print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + op.wait() + print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + torch.cuda.synchronize() + + +def check_reduce_scatter(): + tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) + tensor = tensor.to(get_current_device()) + print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True) + print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + op.wait() + print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + torch.cuda.synchronize() + + +def check_all_reduce(): + tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) + tensor = tensor.to(get_current_device()) + print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True) + print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + op.wait() + print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + torch.cuda.synchronize() + + +def check_layer(rank, world_size, port): + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + assert dist.get_rank() == gpc.get_global_rank() + print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size())) + + check_all_gather() + check_reduce_scatter() + check_all_reduce() + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_comm(): + world_size = 4 + run_func = partial(check_layer, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_comm() diff --git a/tests/test_comm/test_object_list_p2p.py b/tests/test_comm/test_object_list_p2p.py new file mode 100644 index 0000000000000000000000000000000000000000..701e3e8ade797e01d4ecb0cfbe46123fb372c8ae --- /dev/null +++ b/tests/test_comm/test_object_list_p2p.py @@ -0,0 +1,105 @@ +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from colossalai.communication.p2p import send_forward, recv_forward, send_backward, recv_backward, send_forward_recv_backward, send_backward_recv_forward +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.utils import free_port, get_current_device +from colossalai.testing import rerun_if_address_is_in_use + +CONFIG = dict(parallel=dict(pipeline=2)) +torch.manual_seed(123) +LIST_LENGTH = 3 +TENSOR_SIZE = torch.Size((3, 3)) +TENSOR_SIZE_LIST = [TENSOR_SIZE for i in range(LIST_LENGTH)] +data = torch.rand(3, 3) +data_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)] +grad = torch.rand(3, 3) +grad_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)] + + +def check_send_recv_forward(): + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: + device = torch.device('cuda:0') + data_to_send = data.to(device) + data_list_to_send = [] + for data_in_list in data_list: + data_list_to_send.append(data_in_list.to(device)) + send_forward(data_to_send) + send_forward(data_list_to_send) + else: + device = torch.device('cuda:1') + data_recv = recv_forward(TENSOR_SIZE) + data_list_recv = recv_forward(TENSOR_SIZE_LIST) + data_to_check = data.to(device) + assert data_recv.equal(data_to_check) + for data_recv, data_send in zip(data_list_recv, data_list): + data_to_check = data_send.to(device) + assert data_recv.equal(data_to_check) + + +def check_send_recv_backward(): + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: + device = torch.device('cuda:0') + grad_recv = recv_backward(TENSOR_SIZE) + grad_list_recv = recv_backward(TENSOR_SIZE_LIST) + grad_to_check = grad.to(device) + assert grad_recv.equal(grad_to_check) + for grad_recv, grad_send in zip(grad_list_recv, grad_list): + grad_to_check = grad_send.to(device) + assert grad_recv.equal(grad_to_check) + else: + device = torch.device('cuda:1') + grad_to_send = grad.to(device) + grad_list_to_send = [] + for grad_in_list in grad_list: + grad_list_to_send.append(grad_in_list.to(device)) + send_backward(grad_to_send) + send_backward(grad_list_to_send) + + +def check_send_recv_forward_backward(): + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: + device = torch.device('cuda:0') + data_list_to_send = [] + for data_in_list in data_list: + data_list_to_send.append(data_in_list.to(device)) + grad_list_recv = send_forward_recv_backward(data_list_to_send, TENSOR_SIZE_LIST) + + for grad_recv, grad_send in zip(grad_list_recv, grad_list): + grad_to_check = grad_send.to(device) + assert grad_recv.equal(grad_to_check) + else: + device = torch.device('cuda:1') + grad_list_to_send = [] + for grad_in_list in grad_list: + grad_list_to_send.append(grad_in_list.to(device)) + data_list_recv = send_backward_recv_forward(grad_list_to_send, TENSOR_SIZE_LIST) + for data_recv, data_send in zip(data_list_recv, data_list): + data_to_check = data_send.to(device) + assert data_recv.equal(data_to_check) + + +def check_layer(rank, world_size, port): + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_send_recv_forward() + check_send_recv_backward() + check_send_recv_forward_backward() + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_object_list_p2p(): + world_size = 2 + run_func = partial(check_layer, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_object_list_p2p() diff --git a/tests/test_comm/test_object_list_p2p_v2.py b/tests/test_comm/test_object_list_p2p_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..c639ac9f8ef3c01be4b9dc86029e2fd1dc5d582f --- /dev/null +++ b/tests/test_comm/test_object_list_p2p_v2.py @@ -0,0 +1,132 @@ +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from colossalai.communication.p2p_v2 import send_forward, recv_forward, send_backward, recv_backward, init_process_group +from colossalai.context import ParallelMode, Initializer_Pipeline +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.utils import free_port, get_current_device +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.logging import disable_existing_loggers + +disable_existing_loggers() + +# config +world_size = 4 +CONFIG = dict(parallel=dict(pipeline=4)) +torch.manual_seed(123) +use_scatter_gather_tensors = False + +# data +torch.manual_seed(123) +LIST_LENGTH = 3 +TENSOR_SIZE = torch.Size((3, 3)) +TENSOR_SIZE_LIST = [TENSOR_SIZE for i in range(LIST_LENGTH)] +data = torch.rand(3, 3) +data_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)] +grad = torch.rand(3, 3) +grad_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)] + + +def check_send_recv_forward(): + disable_existing_loggers() + local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + if local_rank == 0: + device = torch.device('cuda:0') + data_to_send = data.to(device) + data_list_to_send = [] + for data_in_list in data_list: + data_list_to_send.append(data_in_list.to(device)) + + send_forward(data_to_send, scatter_gather_tensors=use_scatter_gather_tensors) + send_forward(data_list_to_send, scatter_gather_tensors=use_scatter_gather_tensors) + + elif local_rank == 1: + device = torch.device('cuda:1') + + data_recv = recv_forward(TENSOR_SIZE, scatter_gather_tensors=use_scatter_gather_tensors) + data_list_recv = recv_forward(TENSOR_SIZE_LIST, scatter_gather_tensors=use_scatter_gather_tensors) + + data_to_check = data.to(device) + + assert data_recv.equal(data_to_check) + + for data_recv, data_send in zip(data_list_recv, data_list): + data_to_check = data_send.to(device) + data_recv = data_recv.to(device) + assert data_recv.equal(data_to_check) + + +def check_send_recv_backward(): + disable_existing_loggers() + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: + device = torch.device('cuda:0') + grad_recv = recv_backward(TENSOR_SIZE) + grad_list_recv = recv_backward(TENSOR_SIZE_LIST) + + grad_to_check = grad.to(device) + grad_recv = grad_recv[0].to(device) + + assert grad_recv.equal(grad_to_check) + for grad_recv, grad_send in zip(grad_list_recv, grad_list): + grad_recv = grad_recv.to(device) + grad_to_check = grad_send.to(device) + assert grad_recv.equal(grad_to_check) + else: + device = torch.device('cuda:1') + grad_to_send = grad.to(device) + grad_list_to_send = [] + for grad_in_list in grad_list: + grad_list_to_send.append(grad_in_list.to(device)) + send_backward(grad_to_send) + send_backward(grad_list_to_send) + + +def check_small_pipeline(): + disable_existing_loggers() + # make sure the rank is 4 + assert gpc.world_size == 4, "make sure to set world size to 4 to start the training process" + local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + if local_rank == 0: + obj = [1, torch.randn(2, 2).cuda(), None] + send_forward(obj) + elif local_rank == 1: + obj = recv_forward() + send_forward(obj) + elif local_rank == 2: + obj = recv_forward() + send_forward(obj) + elif local_rank == 3: + obj = recv_forward() + else: + pass + + +def check_layer(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + disable_existing_loggers() + # check_send_recv_forward() + check_small_pipeline() + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_object_list_p2p(): + disable_existing_loggers() + run_func = partial(check_layer, world_size=world_size, port=free_port()) + disable_existing_loggers() + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + disable_existing_loggers() + test_object_list_p2p() diff --git a/tests/test_config/sample_config.py b/tests/test_config/sample_config.py new file mode 100644 index 0000000000000000000000000000000000000000..08ca108281b9c0700fef4ecb2c14416ccbabfd9f --- /dev/null +++ b/tests/test_config/sample_config.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +train_data = dict( + dataset=dict( + type='CIFAR10Dataset', + root='/path/to/data', + download=True, + transform_pipeline=[ + dict(type='RandomResizedCrop', size=224), + dict(type='RandomHorizontalFlip'), + dict(type='ToTensor'), + dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ] + ), + dataloader=dict( + batch_size=64, + pin_memory=True, + num_workers=4, + sampler=dict( + type='DataParallelSampler', + shuffle=True, + ) + ) +) diff --git a/tests/test_config/test_load_config.py b/tests/test_config/test_load_config.py new file mode 100644 index 0000000000000000000000000000000000000000..550af2a4ae81656b4da0c159ce5a04bfdbb891cc --- /dev/null +++ b/tests/test_config/test_load_config.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from pathlib import Path + +import pytest + +from colossalai.context.config import Config + + +@pytest.mark.cpu +def test_load_config(): + filename = Path(__file__).parent.joinpath('sample_config.py') + config = Config.from_file(filename) + + assert config.train_data, 'cannot access train data as attribute' + assert config.train_data.dataset, 'cannot access grandchild attribute' + assert isinstance(config.train_data.dataset.transform_pipeline[0], dict), \ + f'expected attribute transform_pipeline elements to be a dict, but found {type(config.train_data.dataset.transform_pipeline)}' diff --git a/tests/test_context/configs/parallel_2d_init.py b/tests/test_context/configs/parallel_2d_init.py new file mode 100644 index 0000000000000000000000000000000000000000..6af884450ad0fee42d86fd1ad7ee950d576dd7da --- /dev/null +++ b/tests/test_context/configs/parallel_2d_init.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +parallel = dict( + pipeline=dict(size=2), + tensor=dict( + size=4, + mode='2d' + ) +) diff --git a/tests/test_context/configs/parallel_2p5d_init.py b/tests/test_context/configs/parallel_2p5d_init.py new file mode 100644 index 0000000000000000000000000000000000000000..c2d896d383e26d1530bd05d4127dfdafec57d826 --- /dev/null +++ b/tests/test_context/configs/parallel_2p5d_init.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +parallel = dict( + pipeline=dict(size=2), + tensor=dict( + size=8, + depth=2, + mode='2.5d' + ) +) diff --git a/tests/test_context/configs/parallel_3d_init.py b/tests/test_context/configs/parallel_3d_init.py new file mode 100644 index 0000000000000000000000000000000000000000..0ec724f8bb4f2513457568eaeb221727e4da2ff1 --- /dev/null +++ b/tests/test_context/configs/parallel_3d_init.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +parallel = dict( + pipeline=dict(size=2), + tensor=dict( + size=8, + mode='3d' + ) +) diff --git a/tests/test_context/test_hybrid_parallel.py b/tests/test_context/test_hybrid_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..f311b1d2e7364007e7155602777a51c4c06910dd --- /dev/null +++ b/tests/test_context/test_hybrid_parallel.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from functools import partial +from pathlib import Path +import pytest +import torch +import torch.multiprocessing as mp + +from colossalai import launch +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import free_port +from colossalai.context import reset_seeds +from colossalai.global_variables import tensor_parallel_env as tp_env +from colossalai.testing import rerun_if_address_is_in_use + +CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py')) + + +def check_data_parallel_rank(rank): + global_world_size = gpc.get_world_size(ParallelMode.GLOBAL) + mp_size = gpc.get_world_size(ParallelMode.MODEL) + num_dp_groups = global_world_size // mp_size + dp_local_rank = gpc.get_local_rank(ParallelMode.DATA) + + assert gpc.get_world_size(ParallelMode.DATA) == num_dp_groups + + for group_idx in range(num_dp_groups): + ranks_in_dp_group = range(group_idx * mp_size, (group_idx + 1) * mp_size) + if rank in ranks_in_dp_group: + assert dp_local_rank == group_idx + + +def check_pipeline_parallel_rank(rank): + mp_world_size = gpc.get_world_size(ParallelMode.MODEL) + tp_world_size = gpc.get_world_size(ParallelMode.TENSOR) + num_pipeline_stage = mp_world_size // tp_world_size + pipeline_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + for stage_idx in range(num_pipeline_stage): + ranks_in_current_stage = range(stage_idx * tp_world_size, (stage_idx + 1) * tp_world_size) + if rank in ranks_in_current_stage: + assert stage_idx == pipeline_local_rank + + +def check_model_parallel_rank(rank): + mp_size = gpc.get_world_size(ParallelMode.MODEL) + rank_within_mp_group = rank % mp_size + mp_local_rank = gpc.get_local_rank(ParallelMode.MODEL) + assert rank_within_mp_group == mp_local_rank + + +def check_tensor_parallel_rank(rank): + if tp_env.mode == '2d': + check_2d_tensor_parallel_rank(rank) + elif tp_env == '2.5d': + check_2p5d_tensor_parallel_rank(rank) + elif tp_env == '3d': + check_3d_tensor_parallel_rank(rank) + + +def get_tp_info(): + global_world_size = gpc.get_world_size(ParallelMode.GLOBAL) + tp_world_size = gpc.get_world_size(ParallelMode.TENSOR) + num_tp_groups = global_world_size // tp_world_size + tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR) + return tp_local_rank, tp_world_size, num_tp_groups + + +def check_2d_tensor_parallel_rank(rank): + tp_local_rank, tp_world_size, num_tp_groups = get_tp_info() + + for group_id in range(num_tp_groups): + ranks_in_current_tp_group = range(group_id * tp_world_size, (group_id + 1) * tp_world_size) + + if rank in ranks_in_current_tp_group: + col_local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + row_local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + + assert col_local_rank == tp_local_rank // tp_env.summa_dim + assert row_local_rank == tp_local_rank % tp_env.summa_dim + + +def check_2p5d_tensor_parallel_rank(rank): + tp_local_rank, tp_world_size, num_tp_groups = get_tp_info() + + for group_id in range(num_tp_groups): + ranks_in_current_tp_group = range(group_id * tp_world_size, (group_id + 1) * tp_world_size) + + if rank in ranks_in_current_tp_group: + rp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + cp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + dp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + xp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_XZ) + + assert rp_rank == tp_local_rank % tp_env.summa_dim + assert cp_rank == tp_local_rank // tp_env.tesseract_dim + assert dp_rank == tp_local_rank // (tp_env.summa_dim**2) + assert xp_rank == tp_local_rank // tp_env.summa_dim + + +def check_3d_tensor_parallel_rank(rank): + tp_local_rank, tp_world_size, num_tp_groups = get_tp_info() + + for group_id in range(num_tp_groups): + ranks_in_current_tp_group = range(group_id * tp_world_size, (group_id + 1) * tp_world_size) + + if rank in ranks_in_current_tp_group: + ip_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT) + wp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT) + op_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT) + + assert ip_rank == tp_local_rank % tp_env.depth_3d + assert wp_rank == tp_local_rank // tp_env.depth_3d + assert op_rank == tp_local_rank // (tp_env.depth_3d**2) + + +def init_context(config_path, rank, world_size, backend, port, host): + dist_args = dict(config=config_path, + rank=rank, + world_size=world_size, + backend=backend, + port=port, + host=host, + verbose=True) + launch(**dist_args) + + check_tensor_parallel_rank(rank) + check_data_parallel_rank(rank) + check_pipeline_parallel_rank(rank) + check_model_parallel_rank(rank) + gpc.destroy() + torch.cuda.empty_cache() + + +def run_dist(rank, world_size, backend, port_list, host): + for config_path, port in zip(CONFIG_PATH_LIST, port_list): + init_context(config_path=config_path, rank=rank, world_size=world_size, backend=backend, port=port, host=host) + reset_seeds() + + +@pytest.mark.cpu +@rerun_if_address_is_in_use() +def test_context(): + """ + As no computation or communication is done, we can run this test on CPU. + """ + world_size = 32 + port_list = [] + + for _ in range(len(CONFIG_PATH_LIST)): + while True: + port = free_port() + if port not in port_list: + port_list.append(port) + break + + test_fn = partial(run_dist, world_size=world_size, backend='gloo', port_list=port_list, host='localhost') + mp.spawn(test_fn, nprocs=world_size) + + +if __name__ == '__main__': + test_context() diff --git a/tests/test_data/test_cifar10_dataset.py b/tests/test_data/test_cifar10_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4b9ca61d9f1796353f48fd96ca2d8b1cbacfa17d --- /dev/null +++ b/tests/test_data/test_cifar10_dataset.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import os +from pathlib import Path + +import pytest +from torchvision import transforms, datasets +from torch.utils.data import DataLoader + + +@pytest.mark.cpu +def test_cifar10_dataset(): + # build transform + transform_pipeline = [transforms.ToTensor()] + transform_pipeline = transforms.Compose(transform_pipeline) + + # build dataset + dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) + + # build dataloader + dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=2) + data_iter = iter(dataloader) + img, label = data_iter.next() + + +if __name__ == '__main__': + test_cifar10_dataset() diff --git a/tests/test_data/test_data_parallel_sampler.py b/tests/test_data/test_data_parallel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..54fa44bdc0c20ecbd673adea7757020e2db5bb35 --- /dev/null +++ b/tests/test_data/test_data_parallel_sampler.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import os +from functools import partial +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from torchvision import transforms, datasets +from colossalai.context import ParallelMode, Config +from colossalai.core import global_context as gpc +from colossalai.utils import get_dataloader, free_port +from colossalai.testing import rerun_if_address_is_in_use + +CONFIG = Config(dict( + parallel=dict( + pipeline=dict(size=1), + tensor=dict(size=1, mode=None), + ), + seed=1024, +)) + + +def run_data_sampler(rank, world_size, port): + dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost') + colossalai.launch(**dist_args) + print('finished initialization') + + # build dataset + transform_pipeline = [transforms.ToTensor()] + transform_pipeline = transforms.Compose(transform_pipeline) + dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) + + # build dataloader + dataloader = get_dataloader(dataset, batch_size=8, add_sampler=True) + + data_iter = iter(dataloader) + img, label = data_iter.next() + img = img[0] + + if gpc.get_local_rank(ParallelMode.DATA) != 0: + img_to_compare = img.clone() + else: + img_to_compare = img + dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA)) + + if gpc.get_local_rank(ParallelMode.DATA) != 0: + assert not torch.equal( + img, img_to_compare), 'Same image was distributed across ranks but expected it to be different' + torch.cuda.empty_cache() + + +@pytest.mark.cpu +@rerun_if_address_is_in_use() +def test_data_sampler(): + world_size = 4 + test_func = partial(run_data_sampler, world_size=world_size, port=free_port()) + mp.spawn(test_func, nprocs=world_size) + + +if __name__ == '__main__': + test_data_sampler() diff --git a/tests/test_data/test_deterministic_dataloader.py b/tests/test_data/test_deterministic_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..4d76e7f137f16227f982acc9305ff6c9b1c2ec6d --- /dev/null +++ b/tests/test_data/test_deterministic_dataloader.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import os +from functools import partial +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torchvision import transforms, datasets + +import colossalai +from colossalai.context import ParallelMode, Config +from colossalai.core import global_context as gpc +from colossalai.utils import get_dataloader, free_port +from colossalai.testing import rerun_if_address_is_in_use +from torchvision import transforms + +CONFIG = Config( + dict( + train_data=dict( + dataset=dict( + type='CIFAR10', + root=Path(os.environ['DATA']), + train=True, + download=True, + ), + dataloader=dict(num_workers=2, batch_size=2, shuffle=True), + ), + parallel=dict( + pipeline=dict(size=1), + tensor=dict(size=1, mode=None), + ), + seed=1024, + )) + + +def run_data_sampler(rank, world_size, port): + dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost') + colossalai.launch(**dist_args) + + # build dataset + transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)] + transform_pipeline = transforms.Compose(transform_pipeline) + dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) + + # build dataloader + dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False) + + data_iter = iter(dataloader) + img, label = data_iter.next() + img = img[0] + + if gpc.get_local_rank(ParallelMode.DATA) != 0: + img_to_compare = img.clone() + else: + img_to_compare = img + dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA)) + + if gpc.get_local_rank(ParallelMode.DATA) != 0: + # this is without sampler + # this should be false if data parallel sampler to given to the dataloader + assert torch.equal(img, + img_to_compare), 'Same image was distributed across ranks and expected it to be the same' + torch.cuda.empty_cache() + + +@pytest.mark.cpu +@rerun_if_address_is_in_use() +def test_data_sampler(): + world_size = 4 + test_func = partial(run_data_sampler, world_size=world_size, port=free_port()) + mp.spawn(test_func, nprocs=world_size) + + +if __name__ == '__main__': + test_data_sampler() diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..3c2390c928379e59337ae23c52db469b61b7fce4 --- /dev/null +++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py @@ -0,0 +1,105 @@ +import os + +from functools import partial +from pathlib import Path + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +from colossalai.amp import AMP_TYPE +from colossalai.trainer import Trainer, hooks +from colossalai.context import ParallelMode +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus +from colossalai.utils import free_port +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.utils import get_dataloader +from colossalai.pipeline.pipelinable import PipelinableContext +from torchvision.datasets import CIFAR10 +from torchvision import transforms + +BATCH_SIZE = 4 +NUM_EPOCHS = 60 +WARMUP_EPOCHS = 5 +CONFIG = dict(NUM_MICRO_BATCHES=2, + parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')), + fp16=dict(mode=AMP_TYPE.NAIVE), + gradient_accumulation=2) + + +def run_trainer(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + logger = get_dist_logger() + + # get logger + logger = get_dist_logger() + + pipelinable = PipelinableContext() + try: + from titans.model.vit import vit_tiny_patch4_32 + except ImportError: + logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed') + logger.warning('please install titan from https://github.com/hpcaitech/Titans') + return + with pipelinable: + model = vit_tiny_patch4_32() + pipelinable.to_layer_list() + pipelinable.policy = "uniform" + model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) + + # craete dataloaders + root = Path(os.environ['DATA']) + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4, pad_if_needed=True), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train) + train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True) + + # create loss function + criterion = CrossEntropyLoss(label_smoothing=0.1) + + # create optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS) + + # intiailize + engine, train_dataloader, *_ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) + + logger = get_dist_logger() + + trainer = Trainer(engine=engine, logger=logger) + + hook_list = [ + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), + ] + + trainer.fit(train_dataloader=train_dataloader, + epochs=NUM_EPOCHS, + max_steps=2, + hooks=hook_list, + display_progress=True) + + +@pytest.mark.dist +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_hybrid_parallel(): + world_size = 8 + run_func = partial(run_trainer, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_hybrid_parallel() diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..480f158821f3c8d67db550059ad2a0c137b0a347 --- /dev/null +++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py @@ -0,0 +1,111 @@ +import os + +from functools import partial +from pathlib import Path + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +from colossalai.amp import AMP_TYPE +from colossalai.trainer import Trainer, hooks +from colossalai.context import ParallelMode +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus +from colossalai.utils import free_port +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.utils import get_dataloader +from colossalai.pipeline.pipelinable import PipelinableContext +from colossalai.logging import disable_existing_loggers +from torchvision.datasets import CIFAR10 +from torchvision import transforms + +from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2 + +disable_existing_loggers() +BATCH_SIZE = 4 +NUM_EPOCHS = 10 +WARMUP_EPOCHS = 5 +CONFIG = dict(NUM_MICRO_BATCHES=2, + parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')), + fp16=dict(mode=AMP_TYPE.NAIVE), + gradient_accumulation=2) + + +def run_trainer(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + disable_existing_loggers() + # get logger + logger = get_dist_logger() + + pipelinable = PipelinableContext() + try: + from titans.model.vit import vit_tiny_patch4_32 + except ImportError: + logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed') + logger.warning('please install titan from https://github.com/hpcaitech/Titans') + return + with pipelinable: + model = vit_tiny_patch4_32() + pipelinable.to_layer_list() + pipelinable.policy = "uniform" + model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) + + # craete dataloaders + root = Path(os.environ['DATA']) + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4, pad_if_needed=True), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train) + train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True) + + # create loss function + criterion = CrossEntropyLoss(label_smoothing=0.1) + + # create optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS) + + # intiailize + engine, train_dataloader, *_ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) + + engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES) + + logger = get_dist_logger() + + trainer = Trainer(engine=engine, logger=logger) + + hook_list = [ + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), + ] + + trainer.fit(train_dataloader=train_dataloader, + max_steps=2, + epochs=NUM_EPOCHS, + hooks=hook_list, + display_progress=True) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_hybrid_parallel(): + world_size = 2 + run_func = partial(run_trainer, world_size=world_size, port=free_port()) + disable_existing_loggers() + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_hybrid_parallel() diff --git a/tests/test_ddp/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py new file mode 100644 index 0000000000000000000000000000000000000000..2be962e1a2e554c81a18fdf4e486e6702e9bcc22 --- /dev/null +++ b/tests/test_ddp/test_ddp_ignore_params.py @@ -0,0 +1,96 @@ +import os +import random +from functools import partial +from typing import Callable, Type + +import numpy as np +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.parallel import ColoDDP, ZeroDDP +from colossalai.tensor import ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext + + +def set_seed(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + +def init_ddp(module: torch.nn.Module) -> ColoDDP: + pg = ProcessGroup() + return ColoDDP(module, process_group=pg) + + +def init_ddpv2(module: torch.nn.Module) -> ZeroDDP: + chunk_config, _ = search_chunk_configuration(module, 4, 1024) + chunk_manager = ChunkManager(chunk_config) + gemini_manager = GeminiManager('cuda', chunk_manager) + return ZeroDDP(module, gemini_manager) + + +class Net(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(3, 3, bias=False) + self.fc2 = torch.nn.Linear(3, 1, bias=False) + + def forward(self, x): + return self.fc2(self.fc1(x)) + + +def run_fwd_bwd(ddp_cls: Type[ColoDDP], init_ddp_func: Callable[[torch.nn.Module], ColoDDP]): + with ColoInitContext(device=get_current_device()): + model = Net().cuda() + w1 = model.fc1.weight + w2 = model.fc2.weight + ddp_cls.set_params_to_ignore([w2]) + model = init_ddp_func(model) + x = torch.rand(2, 3, device=get_current_device()) + logits = model(x) + loss = torch.sum(logits) + model.backward(loss) + + if ddp_cls is ZeroDDP: + w1s_grad = w1 + else: + w1s_grad = w1.grad + + w1_grads = [torch.empty_like(w1) for _ in range(dist.get_world_size())] + dist.all_gather(w1_grads, w1s_grad) + assert torch.equal(w1_grads[0], w1_grads[1]) + w2_grads = [torch.empty_like(w2) for _ in range(dist.get_world_size())] + dist.all_gather(w2_grads, w2.grad) + assert not torch.equal(w2_grads[0], w2_grads[1]) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + set_seed(dist.get_rank()) + run_fwd_bwd(ColoDDP, init_ddp) + run_fwd_bwd(ZeroDDP, init_ddpv2) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_ddp_ignore_params(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_ddp_ignore_params(2) diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..f229364c6eb14eb54f632b1a9fdf11ae4afcba04 --- /dev/null +++ b/tests/test_ddp/test_ddp_state_dict.py @@ -0,0 +1,71 @@ +import copy + +import pytest +import colossalai +import torch +import torch.multiprocessing as mp +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils.cuda import get_current_device +from colossalai.utils import free_port +from colossalai.utils.model.colo_init_context import ColoInitContext +from functools import partial +from tests.components_to_test.registry import non_distributed_component_funcs +from colossalai.nn.parallel import ColoDDP +from collections import OrderedDict +from colossalai.tensor import ProcessGroup, ColoParameter + + +def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict): + for (k1, t1), (k2, t2) in zip(state_dict.items(), other_state_dict.items()): + assert k1 == k2 + + if t1.device != t2.device: + temp_t2 = t2.to(t1.device) + else: + temp_t2 = t2 + + assert torch.equal(t1, temp_t2), "\t{}\n\t{}".format(t1, temp_t2) + + +def init_ddp(module: torch.nn.Module) -> ColoDDP: + pg = ProcessGroup() + return ColoDDP(module, process_group=pg) + + +def run_ddp_state_dict(): + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + torch_model = model_builder().cuda() + with ColoInitContext(device=get_current_device()): + model = model_builder() + model = init_ddp(model) + torch_state_dict = torch_model.state_dict() + + for param in model.parameters(): + if isinstance(param, ColoParameter): + assert param.get_process_group() is not None + model.load_state_dict(torch_state_dict) + + for param in model.parameters(): + if isinstance(param, ColoParameter): + assert param.get_process_group() is not None + + state_dict = model.state_dict() + check_state_dict_equal(torch_state_dict, state_dict) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_ddp_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_state_dict(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_state_dict(2) diff --git a/tests/test_ddp/test_reducer.py b/tests/test_ddp/test_reducer.py new file mode 100644 index 0000000000000000000000000000000000000000..5b302d99ffb12b7d73d39c664757ab2d88927841 --- /dev/null +++ b/tests/test_ddp/test_reducer.py @@ -0,0 +1,48 @@ +import pytest +import colossalai +import torch +import torch.multiprocessing as mp +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils.cuda import get_current_device +from colossalai.utils import free_port +from functools import partial +from colossalai.nn.parallel.reducer import Reducer +import torch.distributed as dist +from torch.distributed.distributed_c10d import _get_default_group + +REDUCE_CNT = 0 + + +def check_eq(grad, grad_clone): + global REDUCE_CNT + print(f'Rank{dist.get_rank()} check {REDUCE_CNT}') + REDUCE_CNT += 1 + assert torch.allclose(grad, grad_clone) + + +def run_reducer(): + grads = [torch.rand(64, i + 1, device=get_current_device()) for i in range(10)] + grads_clone = [g.clone().detach() for g in grads] + for g in grads: + dist.all_reduce(g) + reducer = Reducer(bucket_size_mb=1) + for g, g_clone in zip(grads, grads_clone): + reducer.all_reduce_async(g_clone, _get_default_group(), partial(check_eq, g)) + reducer.flush() + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_reducer() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_reducer(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_reducer(2) diff --git a/tests/test_device/test_alpha_beta.py b/tests/test_device/test_alpha_beta.py new file mode 100644 index 0000000000000000000000000000000000000000..5b076fdf0a1d4809447c3b8ab9fab76a3be98791 --- /dev/null +++ b/tests/test_device/test_alpha_beta.py @@ -0,0 +1,14 @@ +import pytest + +from colossalai.device import profile_alpha_beta + + +@pytest.mark.skip(reason="Skip because assertion fails for CI devices") +def test_profile_alpha_beta(): + physical_devices = [0, 1, 2, 3] + (alpha, beta) = profile_alpha_beta(physical_devices) + assert alpha > 0 and alpha < 1e-4 and beta > 0 and beta < 1e-10 + + +if __name__ == '__main__': + test_profile_alpha_beta() diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..3be057b3a98bba70c3e162d7148fb9c66ee84792 --- /dev/null +++ b/tests/test_device/test_device_mesh.py @@ -0,0 +1,21 @@ +from colossalai.device.device_mesh import DeviceMesh +import torch + + +def test_device_mesh(): + physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + mesh_shape = (4, 4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7], + # [8, 9, 10,11], + # [12,13,14,15]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + assert device_mesh.convert_map[5] == [1, 1] + assert device_mesh.convert_map[11] == [2, 3] + assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]] + assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]] + assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3] + + +if __name__ == '__main__': + test_device_mesh() diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py new file mode 100644 index 0000000000000000000000000000000000000000..3172897fb5cda5a9cb665f465158339de1a0d6cf --- /dev/null +++ b/tests/test_device/test_init_logical_pg.py @@ -0,0 +1,49 @@ +import torch +from functools import partial +import pytest +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.distributed import ReduceOp + +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.device.device_mesh import DeviceMesh + + +def check_layer(rank, world_size, port): + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + physical_mesh_id = torch.arange(0, 4) + assert rank == gpc.get_global_rank() + + tensor_to_check = torch.tensor([2, 2, 2, 2]).cuda() + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]} + logical_process_groups = device_mesh.process_groups_dict + + for mesh_dim, pgs in logical_pg_dict.items(): + for index, pg in enumerate(pgs): + if rank in pg: + tensor = torch.ones(4).cuda() + group = logical_process_groups[mesh_dim][index][1] + dist.all_reduce(tensor, op=ReduceOp.SUM, group=group) + assert tensor.equal(tensor_to_check) + + gpc.destroy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_logical_pg(): + world_size = 4 + run_func = partial(check_layer, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_logical_pg() diff --git a/tests/test_engine/test_engine.py b/tests/test_engine/test_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..fb5bd1e1602e4dd2f6627b1abcfa28ed2c1d12f5 --- /dev/null +++ b/tests/test_engine/test_engine.py @@ -0,0 +1,67 @@ +from functools import partial + +import colossalai +import pytest +import torch.multiprocessing as mp +from colossalai.amp import AMP_TYPE +from colossalai.core import global_context as gpc +from colossalai.utils import free_port +from tests.components_to_test.registry import non_distributed_component_funcs +from colossalai.testing import parameterize, rerun_if_address_is_in_use + +CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), + fp16=dict(mode=None), + clip_grad_norm=1.0) + + +@parameterize('model_name', ['repeated_computed_layers', 'resnet18', 'repeated_computed_layers']) +@parameterize('amp_mode', [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None]) +def run_train(model_name, amp_mode): + # FIXME: test bert + get_components_func = non_distributed_component_funcs.get_callable(model_name) + gpc.config.fp16['mode'] = amp_mode + model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() + + model = model_builder(checkpoint=False) + engine, train_dataloader, *args = colossalai.initialize(model=model, + optimizer=optimizer_class(model.parameters(), lr=1e-3), + criterion=criterion, + train_dataloader=train_dataloader) + + try: + engine.train() + for data, label in train_dataloader: + engine.zero_grad() + data = data.cuda() + label = label.cuda() + if criterion: + output = engine(data) + loss = engine.criterion(output, label) + else: + loss = engine(data, label) + engine.backward(loss) + engine.step() + break + except IndexError: + # if using apex amp, NetWithRepeatedlyComputedLayers will raise an index out of range issue + # the following check fails in apex + # if cached_x.grad_fn.next_functions[1][0].variable is not x: + pass + + +def run_engine(rank, world_size, port): + # init dist env + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_train() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_engine(): + world_size = 2 + run_func = partial(run_engine, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_engine() diff --git a/tests/test_engine/test_gradient_accumluation.py b/tests/test_engine/test_gradient_accumluation.py new file mode 100644 index 0000000000000000000000000000000000000000..7f5ee47be8e687f4c01ff4487b9505d3b4bb28ba --- /dev/null +++ b/tests/test_engine/test_gradient_accumluation.py @@ -0,0 +1,99 @@ +import os +from functools import partial +from pathlib import Path + +import colossalai +from colossalai.testing.utils import rerun_if_address_is_in_use +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.utils import free_port, get_dataloader +from colossalai.testing import rerun_if_address_is_in_use +from torch.optim import Adam +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet18 + +# Config +BATCH_SIZE = 2 +NUM_CLASSES = 10 + +CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), + clip_grad_norm=1.0, + gradient_accumulation=4) + + +def run_no_pipeline(rank, world_size, port): + + # init dist env + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # build model + model = resnet18(num_classes=10) + + # build dataloaders + train_dataset = CIFAR10(root=Path(os.environ['DATA']), + download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ])) + train_dataloader = get_dataloader(dataset=train_dataset, + shuffle=True, + batch_size=BATCH_SIZE, + pin_memory=True, + drop_last=True) + + # build optimizer + optimizer = Adam(model.parameters(), lr=0.001) + criterion = nn.CrossEntropyLoss() + + engine, train_dataloader, *args = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) + logger = get_dist_logger() + rank = torch.distributed.get_rank() + param_track = [] + grad_track = [] + next(model.parameters()).retain_grad() + + engine.train() + step = 0 + for img, label in train_dataloader: + engine.zero_grad() + img = img.cuda() + label = label.cuda() + output = engine(img) + loss = engine.criterion(output, label) + engine.backward(loss) + engine.step() + + # check + param_track.append(next(model.parameters())[0].clone()) + grad_track.append(next(model.parameters()).grad[0].clone()) + step += 1 + if step == CONFIG['gradient_accumulation']: + break + + assert not torch.all(grad_track[0] == grad_track[-1]), 'grad should be different in different iterations' + assert torch.all(param_track[0] == param_track[1]) and not torch.all(param_track[0] == param_track[-1]), \ + 'param should be the same in the first few iterations and only changed in the last iteration' + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_engine(): + world_size = 4 + func = partial(run_no_pipeline, world_size=world_size, port=free_port()) + mp.spawn(func, nprocs=world_size) + + +if __name__ == '__main__': + test_engine() diff --git a/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py b/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py new file mode 100644 index 0000000000000000000000000000000000000000..773cf151d2e971ea28becc7f242d968d574addd0 --- /dev/null +++ b/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py @@ -0,0 +1,76 @@ +import copy + +import colossalai +import pytest +import torch +import torch.fx +import torch.multiprocessing as mp +import torchvision.models as tm +from colossalai.core import global_context as gpc +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.fx.passes.algorithms import solver_rotor +from colossalai.fx.passes.algorithms.operation import Sequence +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.utils import free_port + +if is_compatible_with_meta(): + from colossalai.fx.profiler.tensor import MetaTensor + +try: + from colossalai.fx.codegen import ActivationCheckpointCodeGen + withcodegen = True +except: + from colossalai.fx.codegen import python_code_with_activation_checkpoint + withcodegen = False + + +def _run_C_solver_consistency_test(rank=0): + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + + for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]: + model = M() + data = torch.rand(128, 3, 224, 224, device='meta') + + tracer = ColoTracer() + graph = tracer.trace(model, meta_args={"x": data}) + graph.set_codegen(ActivationCheckpointCodeGen()) + gm = ColoGraphModule(model, graph, model.__class__.__name__) + if is_compatible_with_meta(): + data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device) + MetaInfoProp(gm).run(data_meta) + + # python solver + gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024, force_python=True) + sequence_python: Sequence = copy.deepcopy(gm.__sequence__) + opt_python = copy.deepcopy(gm.__opttable__) + + # C solver + gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024) + sequence_C: Sequence = copy.deepcopy(gm.__sequence__) + opt_C = copy.deepcopy(gm.__opttable__) + + # make sure the opt_tables are the same + for m in range(len(opt_python)): + for d in range(1, len(opt_python[0])): + for i in range(len(opt_python[0]) - d): + assert opt_python[m][i][i + d] == opt_C[m][i][i + d], \ + f"item ({m}, {i}, {i + d}) is not consistent with python version!\npython version: {opt_python[m][i][i + d]}\nC version: {opt_C[m][i][i + d]}" + + sequence_python = sequence_python.list_operations() + sequence_C = sequence_C.list_operations() + + # make sure the sequences are the same + assert len(sequence_python) == len(sequence_C) and \ + all(python_op.__repr__() == C_op.__repr__() for (python_op, C_op) in zip(sequence_python, sequence_C)) + + gpc.destroy() + + +@pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0") +def test_C_solver_consistency(): + mp.spawn(_run_C_solver_consistency_test, nprocs=1) + + +if __name__ == '__main__': + _run_C_solver_consistency_test(rank=0) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py new file mode 100644 index 0000000000000000000000000000000000000000..9949d49c1e01fcec17c9182b88191318c04b3688 --- /dev/null +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -0,0 +1,140 @@ +import copy +import re +from typing import Callable + +import pytest +import torch +import torch.multiprocessing as mp +import torchvision.models as tm +from torch.fx import GraphModule + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.fx import ColoTracer +from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.utils import free_port + +if is_compatible_with_meta(): + from colossalai.fx.profiler.tensor import MetaTensor + +try: + from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True +except: + # fall back to older pytorch version + from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False + +SOLVERS = [chen_greedy, solver_rotor] + + +def _is_activation_checkpoint_available(gm: GraphModule): + for n in gm.graph.nodes: + if hasattr(n, 'activation_checkpoint') and getattr(n, 'activation_checkpoint') is not None: + return True + + +def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule): + for m_p, gm_p in zip(m.parameters(), gm.parameters()): + if not torch.allclose(m_p.grad, gm_p.grad): + return False + return True + + +def _is_graph_linearized(gm: GraphModule): + code = gm.code + # find patterns like r' return output_1, output_2', which is not expected on a linearized graph + pattern = re.compile(r' return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+') + if pattern.findall(code): + return False + else: + return True + + +def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule], + model_cls: Callable[[], torch.nn.Module]): + criterion = torch.nn.MSELoss() + m.cuda() + data = torch.rand(2, 3, 32, 32).cuda() + label = torch.rand(2, 5).cuda() + loss = criterion(m(data), label) + loss.backward() + loss = criterion(gm(data), label) + loss.backward() + assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' + + +def _run_ckpt_solver(rank): + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + MODEL_LIST = [tm.densenet121] + + torch.backends.cudnn.deterministic = True + + tracer = ColoTracer(trace_act_ckpt=False) + + data = torch.rand(8, 3, 224, 224, device='meta') + for solver in SOLVERS: + for model_cls in MODEL_LIST: + m = model_cls(num_classes=5) + graph = tracer.trace(root=m) + gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__) + MetaInfoProp(gm.cuda()).run(MetaTensor(data).cuda()) + codegen = ActivationCheckpointCodeGen() + gm.graph.set_codegen(codegen) + if solver == solver_rotor: + gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500) + else: + gm = solver(gm) + assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." + assert _is_activation_checkpoint_available( + gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" + check_backward_consistency(m, gm, solver, model_cls) + gpc.destroy() + + +@pytest.mark.skip("TODO(super-dainiu): refactor all tests.") +@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +def test_ckpt_solver(): + mp.spawn(_run_ckpt_solver, nprocs=1) + + +def _run_ckpt_solver_torch11(rank): + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + MODEL_LIST = [tm.densenet121] + + torch.backends.cudnn.deterministic = True + + tracer = ColoTracer(trace_act_ckpt=False) + + data = torch.rand(8, 3, 32, 32, device='meta') + for solver in SOLVERS: + for model_cls in MODEL_LIST: + m = model_cls(num_classes=5) + graph = tracer.trace(root=m) + gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__) + MetaInfoProp(gm).run(data) + gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph) + if solver == solver_rotor: + gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500, force_python=True) + else: + gm = solver(gm) + assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." + assert _is_activation_checkpoint_available( + gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" + check_backward_consistency(m, gm, solver, model_cls) + gpc.destroy() + + +@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") +def test_ckpt_solver_torch11(): + mp.spawn(_run_ckpt_solver_torch11, nprocs=1) + + +if __name__ == '__main__': + _run_ckpt_solver(rank=0) + test_ckpt_solver() + test_ckpt_solver_torch11() diff --git a/tests/test_fx/test_ckpt_solvers/test_linearize.py b/tests/test_fx/test_ckpt_solvers/test_linearize.py new file mode 100644 index 0000000000000000000000000000000000000000..a803f8c07277daf7466eacbab7a307cdab5c5636 --- /dev/null +++ b/tests/test_fx/test_ckpt_solvers/test_linearize.py @@ -0,0 +1,138 @@ +import pytest +import torch +import torchvision.models as tm +from colossalai.fx import ColoTracer +from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.algorithms import linearize, solver_rotor +from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss) +from colossalai.fx.passes.meta_info_prop import MetaInfoProp + +if is_compatible_with_meta(): + from colossalai.fx.profiler.tensor import MetaTensor + +try: + from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True +except: + # fall back to older pytorch version + from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False + + +@pytest.mark.skip(reason='TODO: modify the logger') +@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") +def test_linearize(): + MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]} + tracer = ColoTracer() + for M, budgets in MODEL_DICT.items(): + for budget in budgets: + model = M() + graph = tracer.trace(model) + graph.set_codegen(ActivationCheckpointCodeGen()) + gm = ColoGraphModule(model, graph, model.__class__.__name__) + MetaInfoProp(gm).run(MetaTensor(torch.rand(128, 3, 224, 224, device="meta"), fake_device='cpu')) + node_list = linearize(gm) + gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2) + op_list = gm.__sequence__.list_operations() + loss_op = next(op for op in op_list if isinstance(op, Loss)) + op_list = op_list[:op_list.index(loss_op)] + in_ckpt = False + ckpt_idx = 0 + for idx, op in enumerate(op_list): + if in_ckpt: + if isinstance(op, ForwardNograd): + for n in node_list[idx]: + assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" + assert n.activation_checkpoint[ + 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" + + continue + + if isinstance(op, ForwardEnable): + for n in node_list[idx]: + assert getattr(n, "activation_checkpoint", None) == None, f"{n} should not be annotated!" + in_ckpt = False + + ckpt_idx += 1 + continue + + if isinstance(op, ForwardCheck): + ckpt_idx += 1 + for n in node_list[idx]: + assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" + assert n.activation_checkpoint[ + 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" + + continue + + else: + if isinstance(op, ForwardCheck): + in_ckpt = True + for n in node_list[idx]: + assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" + assert n.activation_checkpoint[ + 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" + + del model + del gm + del node_list + + +@pytest.mark.skip(reason="torch11 meta tensor not implemented") +@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0") +def test_linearize_torch11(): + MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]} + tracer = ColoTracer() + for M, budgets in MODEL_DICT.items(): + for budget in budgets: + model = M() + graph = tracer.trace(model) + gm = ColoGraphModule(model, graph, model.__class__.__name__) + gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph) + node_list = linearize(gm) + gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2) + op_list = gm.__sequence__.list_operations() + loss_op = next(op for op in op_list if isinstance(op, Loss)) + op_list = op_list[:op_list.index(loss_op)] + in_ckpt = False + ckpt_idx = 0 + for idx, op in enumerate(op_list): + if in_ckpt: + if isinstance(op, ForwardNograd): + for n in node_list[idx]: + assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" + assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!" + + continue + + if isinstance(op, ForwardEnable): + for n in node_list[idx]: + assert getattr(n, "activation_checkpoint", None) == None, f"{n} should not be annotated!" + in_ckpt = False + + ckpt_idx += 1 + continue + + if isinstance(op, ForwardCheck): + ckpt_idx += 1 + for n in node_list[idx]: + assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" + assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!" + + continue + + else: + if isinstance(op, ForwardCheck): + in_ckpt = True + for n in node_list[idx]: + assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" + assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!" + + del model + del gm + del node_list + + +if __name__ == "__main__": + test_linearize() diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..83df1bb5e69c5d072e34e4d2c5e30fcfa82a1796 --- /dev/null +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -0,0 +1,182 @@ +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +from torch.fx import GraphModule +from torch.utils.checkpoint import checkpoint + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.fx import ColoTracer +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.utils import free_port + +try: + from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True +except: + # fall back to older pytorch version + from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False + + +class MLP(torch.nn.Module): + + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + + def forward(self, x): + return self.linear1(x), self.linear2(x) + + +class relu(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x): + return self.relu(x) + + +class MyModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.mlp1 = MLP() + self.relu = relu() + self.linear2 = torch.nn.Linear(4, 4) + + def ckpt2(self, x): + return F.relu(x, inplace=True) + + def ckpt3(self, x, y): + return self.linear2(x) + self.linear2(y) + + def forward(self, x, y): + y1, y2 = checkpoint(self.mlp1, x) + y3 = checkpoint(self.relu, x) + + y4 = checkpoint(self.ckpt2, y) + y5 = checkpoint(self.ckpt3, y, y4) + y6 = self.linear2(y4) + return y1 + y2 + y3 + y4 + y5 + y6 + + +def _run_act_ckpt_codegen(rank): + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + + # build model and run forward + model = MyModule() + data1 = torch.rand(4, 4) + data2 = torch.rand(4, 4) + + # copy model to cuda + model = model.to(device="cuda") + data1 = data1.to(device="cuda") + data2 = data2.to(device="cuda") + + non_fx_out = model(data1, data2) + + # trace the module and replace codegen + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(model) + codegen = ActivationCheckpointCodeGen() + graph.set_codegen(codegen) + + # check ops are annotated with ckpt + # also annotate the selected node for offloading + ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu'] + offload_starts = ['mlp1_linear1'] + for node in graph.nodes: + if node.name in ckpt_nodes: + assert 'activation_checkpoint' in node.meta + + # annotate the selected node for offload + if node.name in offload_starts: + node.meta['activation_offload'] = True + + gm = ColoGraphModule(model, graph) + gm.recompile() + + # assert checkpoint function will be generated and + # the offload option is correct + code = graph.python_code('self').src + assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code + + # recompile and verify the outputs are consistent + fx_out = gm(data1, data2) + assert torch.equal(non_fx_out, fx_out) + + gpc.destroy() + + +@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +def test_act_ckpt_codegen(): + mp.spawn(_run_act_ckpt_codegen, nprocs=1) + + +def _run_act_ckpt_python_code_torch11(rank): + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + + # build model and run forward + model = MyModule() + data1 = torch.rand(4, 4) + data2 = torch.rand(4, 4) + + # copy model to cuda + data1 = data1.to(device="cuda") + data2 = data2.to(device="cuda") + + non_fx_out = model(data1, data2) + + # trace the module and replace codegen + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(model) + + # replace a bound method of an object + graph._python_code = python_code_with_activation_checkpoint.__get__(graph) + + # check ops are annotated with ckpt + ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu'] + offload_starts = ['mlp1_linear1'] + for node in graph.nodes: + if node.name in ckpt_nodes: + assert 'activation_checkpoint' in node.meta + + # annotate the selected node for offload + if node.name in offload_starts: + node.meta['activation_offload'] = True + + gm = ColoGraphModule(model, graph) + gm.recompile() + # assert checkpoint function will be generated and + # the offload option is correct + code = graph.python_code('self').src + assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code + + # recompile and verify the outputs are consistent + fx_out = gm(data1, data2) + assert torch.equal(non_fx_out, fx_out) + + gpc.destroy() + + +@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") +def test_act_ckpt_python_code_torch11(): + mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1) + + +if __name__ == '__main__': + _run_act_ckpt_codegen(rank=0) diff --git a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..6b3a49d181e1e6adb82230d0bad5823fb244fc60 --- /dev/null +++ b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py @@ -0,0 +1,154 @@ +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +from torch.fx import GraphModule +from torch.utils.checkpoint import checkpoint + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.fx import ColoTracer +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.utils import free_port + +try: + from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True +except: + # fall back to older pytorch version + from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False + + +class MyModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + self.linear3 = torch.nn.Linear(4, 4) + self.linear4 = torch.nn.Linear(4, 4) + self.linear5 = torch.nn.Linear(4, 4) + self.linear6 = torch.nn.Linear(4, 4) + + def forward(self, x): + return self.linear6(self.linear5(self.linear4(self.linear3(self.linear2(self.linear1(x)))))) + + +def _run_act_ckpt_codegen(rank): + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + + # build model and run forward + model = MyModule() + data1 = torch.rand(4, 4) + + # copy model to cuda + model = model.to(device="cuda") + data1 = data1.to(device="cuda") + + non_fx_out = model(data1) + + # trace the module and replace codegen + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(model) + codegen = ActivationCheckpointCodeGen() + graph.set_codegen(codegen) + + # annotate nested checkpoint + for node in graph.nodes: + if node.name == "linear1": + node.meta['activation_checkpoint'] = [0, 0, 0] + continue + if node.name == "linear2": + node.meta['activation_checkpoint'] = [0, 0, None] + if node.name == "linear3": + node.meta['activation_checkpoint'] = [0, 0, 1] + if node.name == "linear4": + node.meta['activation_checkpoint'] = [0, 1, None] + if node.name == "linear5": + node.meta['activation_checkpoint'] = 1 + gm = ColoGraphModule(model, graph) + gm.recompile() + + # assert checkpoint function will be generated and + code = graph.python_code('self').src + assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code + + # recompile and verify the outputs are consistent + fx_out = gm(data1) + assert torch.equal(non_fx_out, fx_out) + + gpc.destroy() + + +@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +def test_act_ckpt_codegen(): + mp.spawn(_run_act_ckpt_codegen, nprocs=1) + + +def _run_act_ckpt_python_code_torch11(rank): + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + + # build model and run forward + model = MyModule() + data1 = torch.rand(4, 4) + + # copy model to cuda + model = model.to(device="cuda") + data1 = data1.to(device="cuda") + + non_fx_out = model(data1) + + # trace the module and replace codegen + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(model) + codegen = ActivationCheckpointCodeGen() + graph.set_codegen(codegen) + + # annotate nested checkpoint + for node in graph.nodes: + if node.name == "linear1": + node.meta['activation_checkpoint'] = [0, 0, 0] + continue + if node.name == "linear2": + node.meta['activation_checkpoint'] = [0, 0, None] + if node.name == "linear3": + node.meta['activation_checkpoint'] = [0, 0, 1] + if node.name == "linear4": + node.meta['activation_checkpoint'] = [0, 1, None] + if node.name == "linear5": + node.meta['activation_checkpoint'] = 1 + gm = ColoGraphModule(model, graph) + gm.recompile() + + # assert checkpoint function will be generated and + code = graph.python_code('self').src + assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code + + # recompile and verify the outputs are consistent + fx_out = gm(data1) + assert torch.equal(non_fx_out, fx_out) + + gpc.destroy() + + +@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") +def test_act_ckpt_python_code_torch11(): + mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1) + + +if __name__ == '__main__': + _run_act_ckpt_codegen(rank=0) diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..5d090066c76310bbb8ff175bbb18e4a2b5485c8e --- /dev/null +++ b/tests/test_fx/test_codegen/test_offload_codegen.py @@ -0,0 +1,179 @@ +import copy + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +from torch.fx import GraphModule + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.fx import ColoTracer +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.utils import free_port + +try: + from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True +except: + # fall back to older pytorch version + from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False + + +class MyNet(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + self.linear0 = torch.nn.Linear(4, 4) + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + self.linear3 = torch.nn.Linear(4, 4) + self.linear4 = torch.nn.Linear(4, 4) + self.linear5 = torch.nn.Linear(4, 4) + self.linear6 = torch.nn.Linear(4, 4) + + def forward(self, x): + x = self.linear0(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + x = self.linear5(x) + x = self.linear6(x) + return x + + +def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool: + for m_p, gm_p in zip(m.parameters(), gm.parameters()): + if not torch.allclose(m_p.grad, gm_p.grad): + return False + return True + + +def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor): + + # test forward + non_fx_out = model(data) + fx_out = gm(data) + assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output" + + # test barckward + loss0 = non_fx_out.sum() + loss0.backward() + loss1 = fx_out.sum() + loss1.backward() + assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one" + + +def _run_offload_codegen(rank): + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + + # build model and input + model = MyNet().cuda() + data = torch.rand(4, 4).cuda() + + # trace the module and replace codegen + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(model) + codegen = ActivationCheckpointCodeGen() + graph.set_codegen(codegen) + + # annotate the activation offload part + # also annotate the activation_checkpoint so we could test both types + # of input offload + for node in graph.nodes: + if node.name == "linear0": + node.meta['activation_offload'] = [0, True, False] + if node.name == "linear1": + node.meta['activation_offload'] = [0, True, False] + if node.name == "linear2": + node.meta['activation_offload'] = [1, True, True] + if node.name == "linear4": + node.meta['activation_offload'] = [2, False, True] + if node.name == "linear5": + node.meta['activation_checkpoint'] = [0] + node.meta['activation_offload'] = True + + gm = ColoGraphModule(copy.deepcopy(model), graph) + gm.recompile() + + # assert we have all the components + code = graph.python_code("self").src + assert "def pack_hook_input(self, x):" in code and \ + "def unpack_hook(self, packed):" in code and \ + "def pack_hook_no_input(self, x):" in code and \ + "setattr(x, 'offload', True)" in code and \ + "setattr(linear3, 'offload', False)" in code and \ + "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ + "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ + "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code + + _test_fwd_and_bwd(model, gm, data) + gpc.destroy() + + +@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +def test_act_ckpt_codegen(): + mp.spawn(_run_offload_codegen, nprocs=1) + + +def _run_offload_codegen_torch11(rank): + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + + # build model and input + model = MyNet().cuda() + data = torch.rand(4, 4).cuda() + + # trace the module and replace codegen + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(model) + + # replace a bound method of an object + graph._python_code = python_code_with_activation_checkpoint.__get__(graph) + + # annotate the activation offload part + # also annotate the activation_checkpoint so we could test both types + # of input offload + for node in graph.nodes: + if node.name == "linear0": + node.meta['activation_offload'] = [0, True, False] + if node.name == "linear1": + node.meta['activation_offload'] = [0, True, False] + if node.name == "linear2": + node.meta['activation_offload'] = [1, True, True] + if node.name == "linear4": + node.meta['activation_offload'] = [2, False, True] + if node.name == "linear5": + node.meta['activation_checkpoint'] = [0] + node.meta['activation_offload'] = True + + gm = ColoGraphModule(copy.deepcopy(model), graph) + gm.recompile() + + # assert we have all the components + code = graph.python_code("self").src + assert "def pack_hook_input(self, x):" in code and \ + "def unpack_hook(self, packed):" in code and \ + "def pack_hook_no_input(self, x):" in code and \ + "setattr(x, 'offload', True)" in code and \ + "setattr(linear3, 'offload', False)" in code and \ + "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ + "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ + "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code + + _test_fwd_and_bwd(model, gm, data) + gpc.destroy() + + +@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented") +def test_act_ckpt_python_code_torch11(): + mp.spawn(_run_offload_codegen_torch11, nprocs=1) + + +if __name__ == "__main__": + _run_offload_codegen(0) diff --git a/tests/test_fx/test_coloproxy.py b/tests/test_fx/test_coloproxy.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb6cf86466cf4592c4b1f8868ef453dcff93cec --- /dev/null +++ b/tests/test_fx/test_coloproxy.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +from colossalai.fx.proxy import ColoProxy +from colossalai.fx.tracer.tracer import ColoTracer +from torch.fx import GraphModule +import pytest + + +class Conv1D(nn.Module): + + def __init__(self, nf, nx): + super().__init__() + self.nf = nf + w = torch.empty(nx, nf) + nn.init.normal_(w, std=0.02) + self.weight = nn.Parameter(w) + self.bias = nn.Parameter(torch.zeros(nf)) + + def forward(self, x): + size_out = x.shape[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(size_out) + return x + + +def test_coloproxy(): + + tracer = ColoTracer() + model = Conv1D(3, 3) + input_sample = {'x': torch.rand(3, 3).to('meta')} + + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + node = list(gm.graph.nodes)[0] + + proxy = ColoProxy(node=node, tracer=tracer) + proxy.meta_data = torch.empty(4, 2, device='meta') + + assert len(proxy) == 4 + assert proxy.shape[0] == 4 and proxy.shape[1] == 2 + assert proxy.dim() == 2 + assert proxy.dtype == torch.float32 + assert proxy.size(0) == 4 + + +if __name__ == '__main__': + test_coloproxy() diff --git a/tests/test_fx/test_comm_size_compute.py b/tests/test_fx/test_comm_size_compute.py new file mode 100644 index 0000000000000000000000000000000000000000..8825bbb461d66309b43ec83cd8dcf0da4edf14a3 --- /dev/null +++ b/tests/test_fx/test_comm_size_compute.py @@ -0,0 +1,54 @@ +import colossalai +import colossalai.nn as col_nn +import pytest +import torch +import torch.nn as nn +from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.fx.passes.adding_split_node_pass import (split_with_split_nodes_pass, uniform_split_pass) +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.passes.utils import get_comm_size +from torch.fx import symbolic_trace + +is_compatible = is_compatible_with_meta() +if is_compatible: + from colossalai.fx.profiler import MetaTensor + +MODEL_DIM = 16 +BATCH_SIZE = 8 +PIPELINE_SIZE = 2 + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim) + self.linear2 = torch.nn.Linear(dim, dim) + self.linear3 = torch.nn.Linear(dim, dim) + self.linear4 = torch.nn.Linear(dim, dim) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + return x + + +def test_comm_size_compute(): + model = MLP(MODEL_DIM) + input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta') + gm = symbolic_trace(model) + if is_compatible: + input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device) + MetaInfoProp(gm).run(input_sample) + annotated_model = uniform_split_pass(gm, PIPELINE_SIZE) + split_model, split_submodules = split_with_split_nodes_pass(annotated_model) + submodule_list = list(split_model.children()) + comm_size = get_comm_size(submodule_list[0], submodule_list[1]) + # the shape of tensor send from partition 0 to partition 1 is (8, 16) + assert comm_size == 128 + + +if __name__ == '__main__': + test_comm_size_compute() diff --git a/tests/test_fx/test_complete_workflow.py b/tests/test_fx/test_complete_workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..a21a351f8d777fb7b9c6fcbe64c7724f5f680767 --- /dev/null +++ b/tests/test_fx/test_complete_workflow.py @@ -0,0 +1,87 @@ +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn + +import colossalai +from colossalai.fx import ColoTracer +from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass +from colossalai.tensor import ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.model.lazy_init_context import LazyInitContext + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim) + self.linear2 = torch.nn.Linear(dim, dim) + self.dropout = torch.nn.Dropout(0) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.linear1(x) + x = self.dropout(x) + x = self.relu(x) + x = self.linear2(x) + return x + + +def run_workflow(world_size, dev): + # initailization + with LazyInitContext() as ctx: + model = MLP(16) + + for param in model.parameters(): + assert param.is_meta + + # tracing + tracer = ColoTracer() + graph = tracer.trace(model) + gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) + + # annotate + annotated_gm = transformer_mlp_pass(gm, process_group=ProcessGroup(tp_degree=world_size)) + annotated_gm.recompile() + + # materialization and sharding + ctx.lazy_init_parameters(annotated_gm, device=dev) + for param in model.parameters(): + assert not param.is_meta + + # # check sharding + assert list(model.linear1.weight.shape) == [16 // world_size, 16] + assert list(model.linear1.bias.shape) == [16 // world_size] + assert list(model.linear2.weight.shape) == [16, 16 // world_size] + + # test forward to make sure that IR transform will produce the same results + # like how ColoTensor would do it normally + data = torch.rand(4, 16, device=dev) + non_fx_out = model(data) + fx_out = annotated_gm(data) + assert torch.equal(non_fx_out, fx_out), f'{non_fx_out} vs {fx_out}' + + +def run_dist(rank, world_size, dev, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_workflow(world_size, dev) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@pytest.mark.parametrize('dev', ['cuda', 'cpu']) +@rerun_if_address_is_in_use() +def test_complete_workflow(world_size, dev): + if dev == 'cpu' and world_size > 1: + return + run_func = partial(run_dist, world_size=world_size, dev=dev, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_complete_workflow(1, 'cuda') diff --git a/tests/test_fx/test_graph_manipulation.py b/tests/test_fx/test_graph_manipulation.py new file mode 100644 index 0000000000000000000000000000000000000000..fb33e58a778c6fd9d6fc61dc70f187c29232515d --- /dev/null +++ b/tests/test_fx/test_graph_manipulation.py @@ -0,0 +1,50 @@ +import colossalai +import torch +from colossalai.fx.passes.utils import get_leaf, get_top, assign_bfs_level_to_nodes +from colossalai.fx import ColoTracer +from torch.fx import GraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim) + self.linear2 = torch.nn.Linear(dim, dim) + self.linear3 = torch.nn.Linear(dim, dim) + self.linear4 = torch.nn.Linear(dim, dim) + self.linear5 = torch.nn.Linear(dim, dim) + + def forward(self, x): + l1 = self.linear1(x) + l2 = self.linear2(x) + l3 = self.linear3(l1) + l4 = self.linear4(l2) + l5 = self.linear5(l3) + return l4, l5 + + +def test_graph_manipulation(): + model = MLP(4) + tracer = ColoTracer() + graph = tracer.trace(model) + nodes = list(graph.nodes) + x, l1, l2, l3, l4, l5, output = nodes + + leaf_nodes = set(get_leaf(graph)) + top_nodes = set(get_top(graph)) + compare_dict = {x: None, l1: 0, l2: 0, l3: 1, l4: 1, l5: 2, output: None} + assign_bfs_level_to_nodes(graph) + + assert leaf_nodes == set([l4, l5]) + assert top_nodes == set([l1, l2]) + for node in graph.nodes: + if node.op in ('placeholder', 'output'): + assert not hasattr(node, 'bfs_level') + else: + assert node.bfs_level == compare_dict[node] + + +if __name__ == '__main__': + test_graph_manipulation() diff --git a/tests/test_fx/test_meta/test_aten.py b/tests/test_fx/test_meta/test_aten.py new file mode 100644 index 0000000000000000000000000000000000000000..209ded89cfb9cccab3926851260cf858dea4e2f5 --- /dev/null +++ b/tests/test_fx/test_meta/test_aten.py @@ -0,0 +1,81 @@ +from typing import Any, Callable, Union + +import pytest +import torch +import torch.nn as nn +from colossalai.fx._compatibility import is_compatible_with_meta + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + +aten = torch.ops.aten + +registered_meta = { + ('aten.convolution.default', True): [ # (aten ops, requires_backward) + (nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), + (nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4)), + (nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4, 4)), + (nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), + (nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, + dilation=2), torch.rand(2, 3, 4, 4)), + (nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, + dilation=2), torch.rand(2, 3, 4, 4, 4)), + ], + ('aten.native_batch_norm.default', True): [ + (nn.BatchNorm1d(4), torch.rand(2, 4)), + (nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)), + (nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)), + ], + ('aten.native_layer_norm.default', True): [(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),], + ('aten.avg_pool1d.default', True): [ + (nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)), + (nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)), + (nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)), + (nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)), + ], + ('aten.avg_pool2d.default', True): [ + (nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), + (nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), + (nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)), + (nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)), + ], + ('aten.relu.default', True): [ + (nn.ReLU(), torch.rand(4, 3, 1, 2)), + (nn.LeakyReLU(), torch.rand(4, 3, 1, 2)), + (nn.SiLU(), torch.rand(4, 3, 1, 2)), + (nn.GELU(), torch.rand(4, 3, 1, 2)), + (nn.ELU(), torch.rand(4, 3, 1, 2)), + (nn.Sigmoid(), torch.rand(4, 3, 1, 2)), + (nn.Tanh(), torch.rand(4, 3, 1, 2)), + (nn.Hardswish(), torch.rand(4, 3, 1, 2)), + ] +} + + +def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any: + assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.' + assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.' + assert tensor.stride() == meta_tensor.stride( + ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.' + + +def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any: + x.requires_grad = requires_backward + meta_x = MetaTensor(x) + x_out, meta_out = f(x), f(meta_x) + compare_all(x_out, meta_out) + if requires_backward: + x_out.sum().backward() + meta_out.sum().backward() + compare_all(x.grad, meta_x.grad) + + +@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +def test_meta_aten(): + for (aten_op, requires_backward), v in registered_meta.items(): + for f, x in v: + run_and_compare(f, x, requires_backward) + + +if __name__ == '__main__': + test_meta_aten() diff --git a/tests/test_fx/test_meta/test_backward.py b/tests/test_fx/test_meta/test_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..351c02c5744a5e1732cd1ec00b9b37993f602c39 --- /dev/null +++ b/tests/test_fx/test_meta/test_backward.py @@ -0,0 +1,48 @@ +import pytest +import timm.models as tmm +import torch +import torchvision.models as tm +from colossalai.fx._compatibility import is_compatible_with_meta + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + +tm_models = [ + tm.vgg11, + tm.resnet18, + tm.densenet121, + tm.mobilenet_v3_small, + tm.resnext50_32x4d, + tm.wide_resnet50_2, + tm.regnet_x_16gf, + tm.mnasnet0_5, + tm.efficientnet_b0, +] + +tmm_models = [ + tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m, + tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224, + tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100, + tmm.swin_transformer.swin_base_patch4_window7_224 +] + + +@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +def test_torchvision_models(): + for m in tm_models: + model = m() + data = torch.rand(100000, 3, 224, 224, device='meta') + model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward() + + +@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +def test_timm_models(): + for m in tmm_models: + model = m() + data = torch.rand(100000, 3, 224, 224, device='meta') + model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward() + + +if __name__ == '__main__': + test_torchvision_models() + test_timm_models() diff --git a/tests/test_fx/test_meta/test_meta_trace.py b/tests/test_fx/test_meta/test_meta_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..404b6d27d2d403177c5a6f053d214c358bd0aa2a --- /dev/null +++ b/tests/test_fx/test_meta/test_meta_trace.py @@ -0,0 +1,48 @@ +import pytest +import timm.models as tmm +import torch +import torchvision.models as tm +from colossalai.fx._compatibility import is_compatible_with_meta + +if is_compatible_with_meta(): + from colossalai.fx import meta_trace + +tm_models = [ + tm.vgg11, + tm.resnet18, + tm.densenet121, + tm.mobilenet_v3_small, + tm.resnext50_32x4d, + tm.wide_resnet50_2, + tm.regnet_x_16gf, + tm.mnasnet0_5, + tm.efficientnet_b0, +] + +tmm_models = [ + tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m, + tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224, + tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100, + tmm.swin_transformer.swin_base_patch4_window7_224 +] + + +@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +def test_torchvision_models_trace(): + for m in tm_models: + model = m() + data = torch.rand(1000, 3, 224, 224, device='meta') + graph = meta_trace(model, torch.device('cpu'), data) + + +@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +def test_timm_models_trace(): + for m in tmm_models: + model = m() + data = torch.rand(1000, 3, 224, 224, device='meta') + graph = meta_trace(model, torch.device('cpu'), data) + + +if __name__ == '__main__': + test_torchvision_models_trace() + test_timm_models_trace() diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..6fac180d8ba290b382e6780d719e91f930c30544 --- /dev/null +++ b/tests/test_fx/test_meta_info_prop.py @@ -0,0 +1,37 @@ +import torch +from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata +from torch.fx import symbolic_trace + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + +BATCH_SIZE = 2 +DIM_IN = 4 +DIM_OUT = 16 + + +def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): + assert meta_info_spec.shape == orig_tensor.shape + assert meta_info_spec.dtype == orig_tensor.dtype + assert meta_info_spec.stride == orig_tensor.stride() + assert meta_info_spec.numel == orig_tensor.numel() + + +def test_meta_info_prop(): + model = torch.nn.Linear(DIM_IN, DIM_OUT) + input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') + if is_compatible_with_meta(): + input_sample = MetaTensor(input_sample, fake_device='cpu') + orig_output = model(input_sample) + gm = symbolic_trace(model) + MetaInfoProp(gm).run(input_sample) + for node in gm.graph.nodes: + if node.op == 'placeholder': + meta_check(node.meta['tensor_meta'], input_sample) + if node.op == 'output': + meta_check(node.meta['tensor_meta'], orig_output) + + +if __name__ == '__main__': + test_meta_info_prop() diff --git a/tests/test_fx/test_parallel_1d.py b/tests/test_fx/test_parallel_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..8963ba29cb03914cc02cc69dde5a56e3abf66039 --- /dev/null +++ b/tests/test_fx/test_parallel_1d.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers +from colossalai.initialize import launch +from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use +from torch.fx import symbolic_trace +from colossalai.fx.passes import column_shard_linear_pass + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim) + self.linear2 = torch.nn.Linear(dim, dim) + self.linear3 = torch.nn.Linear(dim, dim) + self.linear4 = torch.nn.Linear(dim, dim) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + return x + + +CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=2))) + + +def check_layer(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + input_tensor = torch.rand(2, 16).cuda() + model = MLP(16).cuda() + symbolic_traced = symbolic_trace(model) + output = model(input_tensor) + splitted_gm = column_shard_linear_pass(symbolic_traced) + new_output = splitted_gm(input_tensor) + + assert output.equal(new_output) + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_1d(): + world_size = 2 + run_func = partial(check_layer, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_1d() diff --git a/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py b/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3afc6c97e2bb69778b5fce5668861cba3d3582eb --- /dev/null +++ b/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py @@ -0,0 +1,69 @@ +import torch +from torch.fx import symbolic_trace +from torch.fx import GraphModule +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass +from colossalai.fx import ColoTracer +import inspect +import random +import numpy as np + +MANUAL_SEED = 0 +random.seed(MANUAL_SEED) +np.random.seed(MANUAL_SEED) +torch.manual_seed(MANUAL_SEED) + + +def split_model_and_compare_output(model, data_gen): + model.eval() + + # generate input sample + kwargs = data_gen() + + # get origin output and rng state + cpu_rng_state = torch.get_rng_state() + output = model(**kwargs) + + # tracing model + tracer = ColoTracer() + try: + meta_args = {k: v.to('meta') for k, v in kwargs.items()} + graph = tracer.trace(root=model, meta_args=meta_args) + except Exception as e: + raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + # apply transform passes + annotated_model = balanced_split_pass(gm, 2) + split_model, split_submodules = split_with_split_nodes_pass(annotated_model) + + # get split model + model_part0 = list(split_model.children())[0] + model_part1 = list(split_model.children())[1] + + # set rng state and compute output of split model + torch.set_rng_state(cpu_rng_state) + output_part0 = model_part0(**kwargs) + sig = inspect.signature(model_part1.forward) + if isinstance(output_part0, torch.Tensor): + output_part1 = model_part1(output_part0) + else: + if len(output_part0) > len(sig.parameters): + output_part0 = output_part0[:len(sig.parameters)] + output_part1 = model_part1(*output_part0) + + # get output tensor from HFOutput datastructure + if 'logits' in output: + output_to_compare = output['logits'] + elif 'prediction_logits' in output: + output_to_compare = output['prediction_logits'] + else: + output_to_compare = output['last_hidden_state'] + + # compare output + if isinstance(output_part1, torch.Tensor): + assert output_to_compare.equal(output_part1) + elif isinstance(output_part1, (tuple, list)): + assert output_to_compare.equal(output_part1[0]) + else: + assert False diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_albert.py b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py new file mode 100644 index 0000000000000000000000000000000000000000..6ef861bdefbe105299e31e8f06f8d1eb464fb3c2 --- /dev/null +++ b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py @@ -0,0 +1,40 @@ +import pytest +import torch +import transformers +from hf_utils import split_model_and_compare_output + +BATCH_SIZE = 2 +SEQ_LENGHT = 16 + + +@pytest.mark.skip('balance split v2 is not ready') +def test_single_sentence_albert(): + MODEL_LIST = [ + transformers.AlbertModel, + transformers.AlbertForPreTraining, + transformers.AlbertForMaskedLM, + transformers.AlbertForSequenceClassification, + transformers.AlbertForTokenClassification, + ] + + config = transformers.AlbertConfig(vocab_size=100, + embedding_size=128, + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=256) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + return meta_args + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + split_model_and_compare_output(model, data_gen) + + +if __name__ == '__main__': + test_single_sentence_albert() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_bert.py b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..a7550413fac8f6514e1428b8d7b3999608d84bbc --- /dev/null +++ b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py @@ -0,0 +1,40 @@ +import pytest +import torch +import transformers +from hf_utils import split_model_and_compare_output + +BATCH_SIZE = 2 +SEQ_LENGHT = 16 + + +@pytest.mark.skip('balance split v2 is not ready') +def test_single_sentence_bert(): + MODEL_LIST = [ + transformers.BertModel, + transformers.BertForPreTraining, + transformers.BertLMHeadModel, + transformers.BertForMaskedLM, + transformers.BertForSequenceClassification, + transformers.BertForTokenClassification, + ] + + config = transformers.BertConfig(vocab_size=100, + hidden_size=128, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=256) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + return meta_args + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + split_model_and_compare_output(model, data_gen) + + +if __name__ == '__main__': + test_single_sentence_bert() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..6181c5c0706a97d895fb4d51d012809a2a5336a4 --- /dev/null +++ b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py @@ -0,0 +1,36 @@ +import pytest +import torch +import transformers +from hf_utils import split_model_and_compare_output + +BATCH_SIZE = 64 +SEQ_LENGHT = 16 +NUM_EPOCHS = 2 +NUM_CHUNKS = 1 + + +@pytest.mark.skip('balance split v2 is not ready') +def test_gpt(): + MODEL_LIST = [ + transformers.GPT2Model, + transformers.GPT2LMHeadModel, + transformers.GPT2DoubleHeadsModel, + transformers.GPT2ForTokenClassification, + # transformers.GPT2ForSequenceClassification, # not supported yet + ] + config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=8) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + return kwargs + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + split_model_and_compare_output(model, data_gen) + + +if __name__ == '__main__': + test_gpt() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..1a9b36be82bd9e9fb29580b6fe6807dbb1eb032a --- /dev/null +++ b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py @@ -0,0 +1,31 @@ +import pytest +import torch +import transformers +from hf_utils import split_model_and_compare_output + +BATCH_SIZE = 1 +SEQ_LENGHT = 16 + + +@pytest.mark.skip('balance split v2 is not ready') +def test_opt(): + MODEL_LIST = [ + transformers.OPTModel, + transformers.OPTForCausalLM, + ] + + config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + return kwargs + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + split_model_and_compare_output(model, data_gen) + + +if __name__ == '__main__': + test_opt() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..16d0163746b318dec6be47529b9f9c55d23d8a7e --- /dev/null +++ b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py @@ -0,0 +1,43 @@ +import pytest +import torch +import transformers +from hf_utils import split_model_and_compare_output + +BATCH_SIZE = 1 +SEQ_LENGHT = 16 + + +@pytest.mark.skip('balance split v2 is not ready') +def test_t5(): + MODEL_LIST = [ + transformers.T5Model, + transformers.T5ForConditionalGeneration, + transformers.T5EncoderModel, + ] + + config = transformers.T5Config(vocab_size=100, d_model=128, num_layers=2) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + return kwargs + + def data_gen_for_encoder_only(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + kwargs = dict(input_ids=input_ids) + return kwargs + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + + if isinstance(model, transformers.T5EncoderModel): + data_gen_func = data_gen_for_encoder_only + else: + data_gen_func = data_gen + + split_model_and_compare_output(model, data_gen_func) + + +if __name__ == '__main__': + test_t5() diff --git a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb1f6f4bb237db355e54d7b57730b3e9a6990a5 --- /dev/null +++ b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py @@ -0,0 +1,48 @@ +import pytest +import timm.models as tm +import torch +from timm_utils import split_model_and_compare_output + + +@pytest.mark.skip('balance split v2 is not ready') +def test_timm_models_without_control_flow(): + + MODEL_LIST = [ + tm.resnest.resnest50d, + tm.beit.beit_base_patch16_224, + tm.cait.cait_s24_224, + tm.convmixer.convmixer_768_32, + tm.efficientnet.efficientnetv2_m, + tm.resmlp_12_224, + tm.vision_transformer.vit_base_patch16_224, + tm.deit_base_distilled_patch16_224, + ] + + data = torch.rand(2, 3, 224, 224) + + for model_cls in MODEL_LIST: + model = model_cls() + split_model_and_compare_output(model, data) + + +@pytest.mark.skip('balance split v2 is not ready') +def test_timm_models_with_control_flow(): + torch.backends.cudnn.deterministic = True + + MODEL_LIST_WITH_CONTROL_FLOW = [ + tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100, + tm.swin_transformer.swin_base_patch4_window7_224 + ] + + data = torch.rand(2, 3, 224, 224) + + meta_args = {'x': data.to('meta')} + + for model_cls in MODEL_LIST_WITH_CONTROL_FLOW: + model = model_cls() + split_model_and_compare_output(model, data, meta_args) + + +if __name__ == '__main__': + test_timm_models_without_control_flow() + test_timm_models_with_control_flow() diff --git a/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py b/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aa870e5c7a659c5a9c09eae5a9a5d5ec3b77220f --- /dev/null +++ b/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py @@ -0,0 +1,51 @@ +import torch +from torch.fx import symbolic_trace +from torch.fx import GraphModule +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass +from colossalai.fx import ColoTracer +import inspect +import random +import numpy as np + +MANUAL_SEED = 0 +random.seed(MANUAL_SEED) +np.random.seed(MANUAL_SEED) +torch.manual_seed(MANUAL_SEED) +torch.backends.cudnn.deterministic = True + + +def split_model_and_compare_output(model, data, meta_args=None): + model.eval() + + # get origin output and rng state + cpu_rng_state = torch.get_rng_state() + output = model(data) + + # tracing model + tracer = ColoTracer() + try: + graph = tracer.trace(root=model, meta_args=meta_args) + except Exception as e: + raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + # apply transform passes + annotated_model = balanced_split_pass(gm, 2) + split_model, split_submodules = split_with_split_nodes_pass(annotated_model) + + # get split model + model_part0 = list(split_model.children())[0] + model_part1 = list(split_model.children())[1] + + # set rng state and compute output of split model + torch.set_rng_state(cpu_rng_state) + output_part0 = model_part0(data) + sig = inspect.signature(model_part1.forward) + if isinstance(output_part0, torch.Tensor): + output_part1 = model_part1(output_part0) + else: + if len(output_part0) > len(sig.parameters): + output_part0 = output_part0[:len(sig.parameters)] + output_part1 = model_part1(*output_part0) + assert output.equal(output_part1) diff --git a/tests/test_fx/test_pipeline/test_topo/test_topo.py b/tests/test_fx/test_pipeline/test_topo/test_topo.py new file mode 100644 index 0000000000000000000000000000000000000000..75c74870523cfd263f14be9935f696a212f9f788 --- /dev/null +++ b/tests/test_fx/test_pipeline/test_topo/test_topo.py @@ -0,0 +1,43 @@ +import pytest +import torch +import transformers +from topo_utils import split_model_and_get_DAG, check_topo, MLP + +BATCH_SIZE = 1 +SEQ_LENGHT = 16 + +def test_opt(): + MODEL_LIST = [ + MLP, + transformers.OPTModel, + ] + + CONFIGS = [ + {'dim': 10, 'layers': 12}, + transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4), + ] + + def data_gen_MLP(): + x = torch.zeros((16, 10)) + kwargs = dict(x=x) + return kwargs + + def data_gen_OPT(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + return kwargs + + DATAGEN = [ + data_gen_MLP, + data_gen_OPT, + ] + + for i, model_cls in enumerate(MODEL_LIST): + model = model_cls(config=CONFIGS[i]) + top_mod, topo = split_model_and_get_DAG(model, DATAGEN[i]) + # print(f'{top_mod=}\n----\n{topo=}') + check_topo(top_mod, topo) + +if __name__ == '__main__': + test_opt() \ No newline at end of file diff --git a/tests/test_fx/test_pipeline/test_topo/topo_utils.py b/tests/test_fx/test_pipeline/test_topo/topo_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..55dd65201acd9ddd06d3becc84326ba97b1a7943 --- /dev/null +++ b/tests/test_fx/test_pipeline/test_topo/topo_utils.py @@ -0,0 +1,92 @@ +import torch +from torch.fx import GraphModule +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass +from colossalai.fx import ColoTracer +from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo +from colossalai.pipeline.middleware.adaptor import get_fx_topology +import random +import numpy as np + +MANUAL_SEED = 0 +random.seed(MANUAL_SEED) +np.random.seed(MANUAL_SEED) +torch.manual_seed(MANUAL_SEED) + +class MLP(torch.nn.Module): + def __init__(self, config={}): + super().__init__() + dim = config['dim'] + layers = config['layers'] + self.layers = torch.nn.ModuleList() + + for _ in range(layers): + self.layers.append(torch.nn.Linear(dim, dim, bias=False)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + +def split_model_and_get_DAG(model, data_gen): + model.eval() + + # generate input sample + kwargs = data_gen() + + # tracing model + tracer = ColoTracer() + try: + meta_args = {k: v.to('meta') for k, v in kwargs.items()} + graph = tracer.trace(root=model, meta_args=meta_args) + except Exception as e: + raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + # apply transform passes + annotated_model = balanced_split_pass(gm, 2) + top_module, split_submodules = split_with_split_nodes_pass(annotated_model) + + topo = get_fx_topology(top_module) + for submodule in split_submodules: + if isinstance(submodule, torch.fx.GraphModule): + setattr(submodule, '_topo', topo) + + return top_module, split_submodules[0]._topo + +def check_input(top_module, input_partition: Partition): + partition_output = input_partition.get_output_vals() + arg_pos = 0 + for node in top_module.graph.nodes: + if node.op == 'placeholder': + cur_checkee = partition_output[arg_pos] + to_partition_and_offset = cur_checkee.get() + assert len(to_partition_and_offset) == len(node.users.keys()) + arg_pos += 1 + + assert arg_pos == len(partition_output) + +def check_submod(top_module, part_id, mid_partition: Partition): + partition_input = mid_partition.get_input_vals() + partition_output = mid_partition.get_output_vals() + + cnt = 1 + cur_node = None + for node in top_module.graph.nodes: + if node.name.startswith('submod'): + cnt += 1 + if cnt == part_id: + cur_node = node + break + + assert len(partition_input) == len(cur_node.args) + assert len(partition_output) == len(cur_node.users) + +def check_topo(top_module, topo: Topo): + input_partition = topo.get_input_partition() + mid_partitions = topo.get_mid_partitions() + + check_input(top_module, input_partition) + for part_id, submod in mid_partitions.items(): + check_submod(top_module, part_id, submod) + \ No newline at end of file diff --git a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py new file mode 100644 index 0000000000000000000000000000000000000000..5d47be2c7bea01323cf1938cfa98ffe1175271de --- /dev/null +++ b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py @@ -0,0 +1,66 @@ +import inspect +import random + +import numpy as np +import pytest +import torch +import torchvision +import torchvision.models as tm +from packaging import version +from torch.fx import GraphModule + +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass + +MANUAL_SEED = 0 +random.seed(MANUAL_SEED) +np.random.seed(MANUAL_SEED) +torch.manual_seed(MANUAL_SEED) +torch.backends.cudnn.deterministic = True + + +@pytest.mark.skip('balance split v2 is not ready') +def test_torchvision_models(): + MODEL_LIST = [ + tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, + tm.regnet_x_16gf, tm.efficientnet_b0, tm.mnasnet0_5 + ] + + if version.parse(torchvision.__version__) >= version.parse('0.12.0'): + MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small]) + + tracer = ColoTracer() + data = torch.rand(2, 3, 224, 224) + + for model_cls in MODEL_LIST: + model = model_cls() + model.eval() + cpu_rng_state = torch.get_rng_state() + output = model(data) + graph = tracer.trace(root=model) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + # apply transform passes + annotated_model = balanced_split_pass(gm, 2) + split_model, split_submodules = split_with_split_nodes_pass(annotated_model) + + # get split model + model_part0 = list(split_model.children())[0] + model_part1 = list(split_model.children())[1] + + # set rng state and compute output of split model + torch.set_rng_state(cpu_rng_state) + output_part0 = model_part0(data) + sig = inspect.signature(model_part1.forward) + if isinstance(output_part0, torch.Tensor): + output_part1 = model_part1(output_part0) + else: + if len(output_part0) > len(sig.parameters): + output_part0 = output_part0[:len(sig.parameters)] + output_part1 = model_part1(*output_part0) + assert output.equal(output_part1) + + +if __name__ == '__main__': + test_torchvision_models() diff --git a/tests/test_fx/test_pipeline_passes.py b/tests/test_fx/test_pipeline_passes.py new file mode 100644 index 0000000000000000000000000000000000000000..de8a9402ba5679fb0aa6c0e0239148702a96e6dc --- /dev/null +++ b/tests/test_fx/test_pipeline_passes.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +import colossalai +import colossalai.nn as col_nn +from torch.fx import symbolic_trace +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass, \ + uniform_split_pass, balanced_split_pass_v2 + +import pytest + +MODEL_DIM = 16 +BATCH_SIZE = 8 +PIPELINE_SIZE = 2 + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim) + self.linear2 = torch.nn.Linear(dim, dim) + self.linear3 = torch.nn.Linear(dim, dim) + self.linear4 = torch.nn.Linear(dim, dim) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + return x + + +def pipeline_pass_test_helper(model, data, pass_func): + origin_output = model(data) + symbolic_traced = symbolic_trace(model) + annotated_model = pass_func(symbolic_traced, PIPELINE_SIZE) + split_model, split_submodules = split_with_split_nodes_pass(annotated_model) + output = split_model(data) + assert output.equal(origin_output) + + +def test_pipeline_passes(): + model = MLP(MODEL_DIM) + data = torch.rand(BATCH_SIZE, MODEL_DIM) + pipeline_pass_test_helper(model, data, balanced_split_pass) + pipeline_pass_test_helper(model, data, balanced_split_pass_v2) + pipeline_pass_test_helper(model, data, uniform_split_pass) + + +if __name__ == '__main__': + test_pipeline_passes() diff --git a/tests/test_fx/test_profiler/gpt_utils.py b/tests/test_fx/test_profiler/gpt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aec32268484fbcc3b362207a02b8a022bf660642 --- /dev/null +++ b/tests/test_fx/test_profiler/gpt_utils.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn +from transformers import GPT2Config, GPT2LMHeadModel + + +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +def gpt2_medium(checkpoint=False): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_xl(checkpoint=False): + return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint) diff --git a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..c717960181ad640e38bdbdd82a59a605ad925642 --- /dev/null +++ b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py @@ -0,0 +1,182 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.fx +import torchvision.models as tm +from gpt_utils import gpt2_medium, gpt2_xl +from torch.fx import symbolic_trace + +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.testing.pytest_wrapper import run_on_environment_flag + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + +TM_BATCH_SIZE = 64 +GPT_BATCH_SIZE = 8 +NUM_STEPS = 5 + + +def extract_forward_mem(gm: torch.fx.GraphModule): + node_size = 0 + param_size = 0 + for node in gm.graph.nodes: + node_size += calculate_fwd_tmp(node) + node_size += calculate_fwd_out(node) + param_size = parameter_size(gm) + return (node_size + param_size) / 1024**2, param_size / 1024**2 + + +def extract_forward_flops(gm: torch.fx.GraphModule): + fwd_flop = 0 + bwd_flop = 0 + for node in gm.graph.nodes: + fwd_flop += node.meta.get('fwd_flop', 0) + bwd_flop += node.meta.get('bwd_flop', 0) + return fwd_flop, bwd_flop + + +def gen_tm_data(batch_size: int, shape: Tuple[int, int, int], device='cuda'): + data = torch.rand(batch_size, *shape, device=device) + label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000) + return data, label + + +def gen_gpt_data(batch_size, seq_len, vocab_size, device='cpu'): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + attention_mask = torch.ones_like(input_ids, device=device) + return input_ids, attention_mask + + +def run_tm_forward(gm: torch.fx.GraphModule): + torch.cuda.reset_peak_memory_stats() + forward_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2 + param_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2 + gm.cuda() + param_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2 + gm.train() + for n in range(NUM_STEPS): + torch.cuda.reset_peak_memory_stats() + data, _ = gen_tm_data(TM_BATCH_SIZE, (3, 224, 224)) + + # If we need to dive deep into the memory usage by + # inspecting `saved_tensor_hooks` + + # ===================================================== + # fwd_mem = 0 + # cache = set() + # def pack(x): + # if isinstance(x, torch.Tensor): + # nonlocal fwd_mem, cache + # if x.data_ptr() not in cache: + # fwd_mem += activation_size(x) + # cache.add(x.data_ptr()) + # return x + # def unpack(x): + # return x + # + # with torch.autograd.graph.saved_tensors_hooks(pack, unpack): + # output = gm(data) + # print(f'Memory estimation by saved_tensor_hooks: {fwd_mem / 1024**2}') + # ===================================================== + + output = gm(data) + forward_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2 / NUM_STEPS + del output + return forward_mem, param_mem + + +def run_gpt_forward(gm: torch.fx.GraphModule): + torch.cuda.reset_peak_memory_stats() + forward_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2 + param_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2 + gm.cuda() + param_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2 + for n in range(NUM_STEPS): + torch.cuda.reset_peak_memory_stats() + data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='cuda:0') + + # If we need to dive deep into the memory usage by + # inspecting `saved_tensor_hooks` + + # ===================================================== + # fwd_mem = 0 + # cache = set() + # def pack(x): + # if isinstance(x, torch.Tensor): + # nonlocal fwd_mem, cache + # if x.data_ptr() not in cache: + # fwd_mem += activation_size(x) + # cache.add(x.data_ptr()) + # return x + # def unpack(x): + # return x + # + # with torch.autograd.graph.saved_tensors_hooks(pack, unpack): + # output = gm(data, mask) + # print(f'Memory estimation by saved_tensor_hooks: {fwd_mem / 1024**2}') + # ===================================================== + + output = gm(data, mask) + forward_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2 / NUM_STEPS + del output + return forward_mem, param_mem + + +@run_on_environment_flag(name='FX_PROFILER') +def test_meta_info_prop(): + for m in [ + tm.alexnet, tm.resnet18, tm.resnet34, tm.resnet50, tm.resnet101, tm.resnet152, tm.densenet121, + tm.densenet161, tm.densenet169, tm.densenet201, tm.convnext_tiny, tm.convnext_small, tm.convnext_base, + tm.convnext_large, tm.wide_resnet50_2, tm.wide_resnet101_2, tm.regnet_x_16gf, tm.mnasnet0_5, + tm.efficientnet_b0, tm.shufflenet_v2_x0_5, tm.shufflenet_v2_x1_0, tm.shufflenet_v2_x1_5, + tm.shufflenet_v2_x2_0, tm.mobilenet_v2, tm.mobilenet_v3_small, tm.mobilenet_v3_large, tm.resnext50_32x4d, + tm.resnext101_32x8d, tm.resnext101_64x4d, tm.vit_b_16, tm.vit_b_32, tm.vit_h_14, tm.vit_l_16, tm.vit_l_32, + tm.vgg11, tm.vgg11_bn, tm.vgg13, tm.vgg13_bn, tm.vgg16, tm.vgg16_bn, tm.vgg19, tm.vgg19_bn + ]: + model = m().cuda() + model.train() + data = MetaTensor(torch.rand(int(TM_BATCH_SIZE), 3, 224, 224, device='meta'), fake_device='cuda:0') + gm = symbolic_trace(model) + interp = MetaInfoProp(gm) + interp.propagate(data) + gm.cpu() + + meta_forward_mem, meta_param_mem = extract_forward_mem(gm) + fwd_flop, bwd_flop = extract_forward_flops(gm) + concrete_forward_mem, concrete_param_mem = run_tm_forward(gm) + + print( + f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|' + ) + del model, gm + + +@run_on_environment_flag(name='FX_PROFILER') +def test_gpt_meta_info_prop(): + for m in [gpt2_medium]: + model = m().cuda() + model.train() + data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='meta') + graph = ColoTracer().trace(model, meta_args={'input_ids': data, 'attention_mask': mask}) + gm = torch.fx.GraphModule(model, graph) + interp = MetaInfoProp(gm) + interp.propagate(MetaTensor(data, fake_device='cuda:0'), MetaTensor(mask, fake_device='cuda:0')) + model.cpu() + + fwd_flop, bwd_flop = extract_forward_flops(gm) + + concrete_forward_mem, concrete_param_mem = run_gpt_forward(gm) + meta_forward_mem, meta_param_mem = extract_forward_mem(gm) + + print( + f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|' + ) + del model, gm + + +if __name__ == '__main__': + test_meta_info_prop() + test_gpt_meta_info_prop() diff --git a/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py new file mode 100644 index 0000000000000000000000000000000000000000..a834951bb6954d19572df3aac69bfed7b77940c6 --- /dev/null +++ b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +from torch.fx import GraphModule +from torch.utils.checkpoint import checkpoint + +from colossalai.fx import ColoTracer + + +class MLP(torch.nn.Module): + + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +# Simple module for demonstration +class MyModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.mlp_1 = MLP() + self.mlp_2 = MLP() + self.output = torch.nn.Linear(4, 4) + + def forward(self, x): + x = checkpoint(self.mlp_1, x) + x = checkpoint(self.mlp_2, x) + x = self.output(x) + return x + + +def test_activation_checkpoint_annotation(): + module = MyModule() + + # test tracing with activation checkpoint + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(module) + gm = GraphModule(module, graph) + + for node in gm.graph.nodes: + if node.name in ['mlp_1_linear1', 'mlp_1_linear2']: + assert node.meta.get('activation_checkpoint', -1) == 0 + + for node in gm.graph.nodes: + if node.name in ['mlp_2_linear1', 'mlp_2_linear2']: + assert node.meta.get('activation_checkpoint', -1) == 1 + + tracer = ColoTracer(trace_act_ckpt=False) + graph = tracer.trace(module) + gm = GraphModule(module, graph) + + for node in gm.graph.nodes: + assert not hasattr(node, 'activation_checkpoint') + + +if __name__ == '__main__': + test_activation_checkpoint_annotation() diff --git a/tests/test_fx/test_tracer/test_bias_addition_module.py b/tests/test_fx/test_tracer/test_bias_addition_module.py new file mode 100644 index 0000000000000000000000000000000000000000..afa30a217604352ac90b0a20e9ee76ea9d4496f8 --- /dev/null +++ b/tests/test_fx/test_tracer/test_bias_addition_module.py @@ -0,0 +1,114 @@ +import torch + +from colossalai.fx import ColoGraphModule, ColoTracer + + +class LinearModel(torch.nn.Module): + + def __init__(self, in_features, out_features): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x): + x = self.linear(x) + x = x * 2 + + return x + + +class ConvModel(torch.nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, bias=True): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias) + + def forward(self, x): + x = self.conv(x) + x = x * 2 + + return x + + +def test_linear_module(): + model = LinearModel(3, 6) + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %linear_weight : [#users=1] = get_attr[target=linear.weight] + # %linear_bias : [#users=1] = get_attr[target=linear.bias] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) + # return mul + graph = tracer.trace(root=model, meta_args={'x': torch.rand(3, 3).to('meta')}) + # def forward(self, x : torch.Tensor): + # linear_weight = self.linear.weight + # linear_bias = self.linear.bias + # linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None + # add = linear + linear_bias; linear = linear_bias = None + # mul = add * 2; add = None + # return mul + gm = ColoGraphModule(model, graph) + gm.recompile() + node_list = list(graph.nodes) + for node in node_list: + if node.op == 'output': + continue + assert hasattr(node, '_meta_data') + weight_node = node_list[1] + bias_node = node_list[2] + linear_node = node_list[3] + add_node = node_list[4] + assert weight_node._meta_data.shape == (6, 3) + assert bias_node._meta_data.shape == (6,) + assert linear_node._meta_data.shape == (3, 6) + assert add_node._meta_data.shape == (3, 6) + + +def test_conv_module(): + model = ConvModel(3, 6, 2) + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # %conv_bias : [#users=1] = get_attr[target=conv.bias] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {}) + # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) + # return mul + graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')}) + # def forward(self, x : torch.Tensor): + # conv_weight = self.conv.weight + # conv_bias = self.conv.bias + # conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None + # view = conv_bias.view([1, -1, 1, 1]); conv_bias = None + # add = conv2d + view; conv2d = view = None + # mul = add * 2; add = None + # return mul + gm = ColoGraphModule(model, graph) + + gm.recompile() + node_list = list(graph.nodes) + for node in node_list: + if node.op == 'output': + continue + assert hasattr(node, '_meta_data') + weight_node = node_list[1] + bias_node = node_list[2] + conv_node = node_list[3] + view_node = node_list[4] + add_node = node_list[5] + assert weight_node._meta_data.shape == (6, 3, 2, 2) + assert bias_node._meta_data.shape == (6,) + assert conv_node._meta_data.shape == (4, 6, 63, 63) + assert view_node._meta_data.shape == (6, 1, 1) + assert add_node._meta_data.shape == (4, 6, 63, 63) + + +if __name__ == '__main__': + test_linear_module() + test_conv_module() diff --git a/tests/test_fx/test_tracer/test_control_flow.py b/tests/test_fx/test_tracer/test_control_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..ed842cff2776c2523c8a8644491a67f57ba8f9ef --- /dev/null +++ b/tests/test_fx/test_tracer/test_control_flow.py @@ -0,0 +1,57 @@ +import torch +import torch.nn as nn +from torch.fx import GraphModule +from colossalai.fx import ColoTracer as Tracer + + +class ControlFlowModel(nn.Module): + + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(10, 10) + self.linear2 = nn.Linear(10, 10) + + def forward(self, x, y): + x1 = self.linear1(x) + y1 = self.linear2(y) + + if x1.dim() == 2: + return x1 + y1 + else: + return x1 - y1 + + +def test_control_flow(): + model = ControlFlowModel() + tracer = Tracer() + graph_branch_true = tracer.trace(model, + meta_args={ + 'x': torch.rand(4, 10, device='meta'), + 'y': torch.rand(4, 10, device='meta') + }) + graph_branch_false = tracer.trace(model, + meta_args={ + 'x': torch.rand(10, device='meta'), + 'y': torch.rand(4, 10, device='meta') + }) + + gm_branch_true = GraphModule(model, graph_branch_true, model.__class__.__name__) + gm_branch_false = GraphModule(model, graph_branch_false, model.__class__.__name__) + gm_branch_true.recompile() + gm_branch_false.recompile() + + # test the true branch + x = torch.rand(4, 10) + y = torch.rand(4, 10) + assert torch.all(model(x, y) == gm_branch_true(x, y)) + assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y)) + + # test the true branch + x = torch.rand(10) + y = torch.rand(4, 10) + assert torch.all(model(x, y) == gm_branch_false(x, y)) + assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y)) + + +if __name__ == '__main__': + test_control_flow() diff --git a/tests/test_fx/test_tracer/test_functional_conv.py b/tests/test_fx/test_tracer/test_functional_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..95670b85f3352522ef1c0d895622ab7ac0da1634 --- /dev/null +++ b/tests/test_fx/test_tracer/test_functional_conv.py @@ -0,0 +1,48 @@ +import torch +from torch.nn import functional as F +from colossalai.fx.tracer.meta_patch import patched_function + + +def test_conv(): + # test F.conv_1d + data_1d = torch.rand(3, 16, 10) + weight_1d = torch.rand(3, 16, 3) + out_1d = F.conv1d(data_1d, weight_1d) + patched_out_1d = patched_function.torch_nn_functional_conv1d(data_1d, weight_1d) + assert out_1d.shape == patched_out_1d.shape + + # test F.conv_transpose1d + weight_1d = torch.transpose(weight_1d, 0, 1) + out_transpose_1d = F.conv_transpose1d(data_1d, weight_1d) + patched_out_transpose_1d = patched_function.torch_nn_functional_convtranspose1d(data_1d, weight_1d) + assert out_transpose_1d.shape == patched_out_transpose_1d.shape + + # test F.conv2d + data_2d = torch.rand(3, 16, 10, 10) + weight_2d = torch.rand(3, 16, 3, 3) + out_2d = F.conv2d(data_2d, weight_2d) + patched_out_2d = patched_function.torch_nn_functional_conv2d(data_2d, weight_2d) + assert out_2d.shape == patched_out_2d.shape + + # test F.conv_transpose2d + weight_2d = torch.transpose(weight_2d, 0, 1) + out_transpose_2d = F.conv_transpose2d(data_2d, weight_2d) + patched_out_transpose_2d = patched_function.torch_nn_functional_convtranspose2d(data_2d, weight_2d) + assert out_transpose_2d.shape == patched_out_transpose_2d.shape + + # test F.conv3d + data_3d = torch.rand(3, 16, 10, 10, 10) + weight_3d = torch.rand(3, 16, 3, 3, 3) + out_3d = F.conv3d(data_3d, weight_3d) + patched_out_3d = patched_function.torch_nn_functional_conv3d(data_3d, weight_3d) + assert out_3d.shape == patched_out_3d.shape + + # test F.conv_transpose3d + weight_3d = torch.transpose(weight_3d, 0, 1) + out_transpose_3d = F.conv_transpose3d(data_3d, weight_3d) + patched_out_transpose_3d = patched_function.torch_nn_functional_convtranspose3d(data_3d, weight_3d) + assert out_transpose_3d.shape == patched_out_transpose_3d.shape + + +if __name__ == '__main__': + test_conv() diff --git a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d93fe0408d77d648000828f4b04bf2d29be2d46 --- /dev/null +++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py @@ -0,0 +1,30 @@ +import torch +from numpy import isin +from torch.fx import GraphModule +from torch.utils._pytree import tree_flatten + +from colossalai.fx import symbolic_trace + + +def trace_model_and_compare_output(model, data_gen): + # must turn on eval mode to ensure the output is consistent + model.eval() + + try: + kwargs = data_gen() + meta_args = {k: v.to('meta') for k, v in kwargs.items()} + gm = symbolic_trace(model, meta_args=meta_args) + except Exception as e: + raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") + + # run forward + inputs = data_gen() + non_fx_out = model(**inputs) + fx_out = gm(**inputs) + + # check output + for k in non_fx_out.keys(): + if torch.is_tensor(fx_out[k]): + assert torch.equal( + fx_out[k], non_fx_out[k] + ), f'{model.__class__.__name__} has incorrect output {k}, expect {non_fx_out[k]}, but got {fx_out[k]}' diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py new file mode 100644 index 0000000000000000000000000000000000000000..9c36b0c9cc96631734074736bd268e7670a60701 --- /dev/null +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py @@ -0,0 +1,66 @@ +import pytest +import torch +import transformers +from hf_tracer_utils import trace_model_and_compare_output + +BATCH_SIZE = 2 +SEQ_LENGTH = 16 + + +def test_single_sentence_albert(): + MODEL_LIST = [ + transformers.AlbertModel, + transformers.AlbertForPreTraining, + transformers.AlbertForMaskedLM, + transformers.AlbertForSequenceClassification, + transformers.AlbertForTokenClassification, + ] + + config = transformers.AlbertConfig(embedding_size=128, + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=256) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + return meta_args + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + trace_model_and_compare_output(model, data_gen) + + +def test_multi_sentence_albert(): + config = transformers.AlbertConfig(hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=256) + tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") + + def data_gen_for_qa(): + question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + inputs = tokenizer(question, text, return_tensors="pt") + return inputs + + model = transformers.AlbertForQuestionAnswering(config) + trace_model_and_compare_output(model, data_gen_for_qa) + + def data_gen_for_mcq(): + prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + choice0 = "It is eaten with a fork and a knife." + choice1 = "It is eaten while held in the hand." + encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) + encoding = {k: v.unsqueeze(0) for k, v in encoding.items()} + return encoding + + model = transformers.AlbertForMultipleChoice(config) + trace_model_and_compare_output(model, data_gen_for_mcq) + + +if __name__ == '__main__': + test_single_sentence_albert() + test_multi_sentence_albert() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..62273e2d51c9ed2535ce0a3c5ee8254cb57eb043 --- /dev/null +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -0,0 +1,69 @@ +import pytest +import torch +import transformers +from hf_tracer_utils import trace_model_and_compare_output + +BATCH_SIZE = 2 +SEQ_LENGTH = 16 + + +def test_single_sentence_bert(): + MODEL_LIST = [ + transformers.BertModel, + transformers.BertForPreTraining, + transformers.BertLMHeadModel, + transformers.BertForMaskedLM, + transformers.BertForSequenceClassification, + transformers.BertForTokenClassification, + ] + + config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + return meta_args + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + trace_model_and_compare_output(model, data_gen) + + +def test_multi_sentence_bert(): + config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256) + tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") + + def data_gen_for_next_sentence(): + prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + next_sentence = "The sky is blue due to the shorter wavelength of blue light." + encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + return encoding + + model = transformers.BertForNextSentencePrediction(config) + trace_model_and_compare_output(model, data_gen_for_next_sentence) + + def data_gen_for_qa(): + question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + inputs = tokenizer(question, text, return_tensors="pt") + return inputs + + model = transformers.BertForQuestionAnswering(config) + trace_model_and_compare_output(model, data_gen_for_qa) + + def data_gen_for_mcq(): + prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + choice0 = "It is eaten with a fork and a knife." + choice1 = "It is eaten while held in the hand." + encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) + encoding = {k: v.unsqueeze(0) for k, v in encoding.items()} + return encoding + + model = transformers.BertForMultipleChoice(config) + trace_model_and_compare_output(model, data_gen_for_mcq) + + +if __name__ == '__main__': + test_single_sentence_bert() + test_multi_sentence_bert() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py new file mode 100644 index 0000000000000000000000000000000000000000..04e874becd0079b2c1a2779691926ae2b219e006 --- /dev/null +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py @@ -0,0 +1,114 @@ +import pytest +import torch +import transformers +from hf_tracer_utils import trace_model_and_compare_output + +from colossalai.fx import symbolic_trace + +try: + import diffusers + HAS_DIFFUSERS = True +except ImportError: + HAS_DIFFUSERS = False + +BATCH_SIZE = 2 +SEQ_LENGTH = 5 +HEIGHT = 224 +WIDTH = 224 +IN_CHANNELS = 3 +LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 8, WIDTH // 8) +TIME_STEP = 2 + + +@pytest.mark.skipif(not HAS_DIFFUSERS, reason="diffusers has not been installed") +def test_vae(): + MODEL_LIST = [ + diffusers.AutoencoderKL, + diffusers.VQModel, + ] + + for model_cls in MODEL_LIST: + model = model_cls() + sample = torch.zeros(LATENTS_SHAPE) + + gm = symbolic_trace(model) + + model.eval() + gm.eval() + + with torch.no_grad(): + fx_out = gm(sample) + non_fx_out = model(sample) + assert torch.allclose( + fx_out['sample'], + non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + + +def test_clip(): + MODEL_LIST = [ + transformers.CLIPModel, + transformers.CLIPTextModel, + transformers.CLIPVisionModel, + ] + + CONFIG_LIST = [ + transformers.CLIPConfig, + transformers.CLIPTextConfig, + transformers.CLIPVisionConfig, + ] + + def data_gen(): + if isinstance(model, transformers.CLIPModel): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + position_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32) + kwargs = dict(input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + pixel_values=pixel_values) + elif isinstance(model, transformers.CLIPTextModel): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + elif isinstance(model, transformers.CLIPVisionModel): + pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32) + kwargs = dict(pixel_values=pixel_values) + return kwargs + + for model_cls, config in zip(MODEL_LIST, CONFIG_LIST): + model = model_cls(config=config()) + trace_model_and_compare_output(model, data_gen) + + +@pytest.mark.skipif(not HAS_DIFFUSERS, reason="diffusers has not been installed") +@pytest.mark.skip(reason='cannot pass the test yet') +def test_unet(): + MODEL_LIST = [ + diffusers.UNet2DModel, + diffusers.UNet2DConditionModel, + ] + + for model_cls in MODEL_LIST: + model = model_cls() + sample = torch.zeros(LATENTS_SHAPE) + + gm = symbolic_trace(model) + + model.eval() + gm.eval() + + with torch.no_grad(): + fx_out = gm(sample, TIME_STEP) + non_fx_out = model(sample, TIME_STEP) + assert torch.allclose( + fx_out['sample'], + non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + + +if __name__ == "__main__": + test_vae() + test_clip() + + # skip because of failure + # test_unet() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..ad4c9684dc4200156d178a7cb5b32e6e581c7ff9 --- /dev/null +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -0,0 +1,36 @@ +import pytest +import torch +import transformers +from hf_tracer_utils import trace_model_and_compare_output + +BATCH_SIZE = 1 +SEQ_LENGTH = 16 + + +# TODO: remove this skip once we handle the latest gpt model +@pytest.mark.skip +def test_gpt(): + MODEL_LIST = [ + transformers.GPT2Model, + transformers.GPT2LMHeadModel, + transformers.GPT2DoubleHeadsModel, + transformers.GPT2ForTokenClassification, + # transformers.GPT2ForSequenceClassification, # not supported yet + ] + + config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + return kwargs + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + trace_model_and_compare_output(model, data_gen) + + +if __name__ == '__main__': + test_gpt() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..06260176ec6f2d2d1e9df3de6af5c6e74710ee15 --- /dev/null +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -0,0 +1,30 @@ +import pytest +import torch +import transformers +from hf_tracer_utils import trace_model_and_compare_output + +BATCH_SIZE = 1 +SEQ_LENGTH = 16 + + +def test_opt(): + MODEL_LIST = [ + transformers.OPTModel, + transformers.OPTForCausalLM, + ] + + config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + return kwargs + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + trace_model_and_compare_output(model, data_gen) + + +if __name__ == '__main__': + test_opt() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..71e782fddc7640e1f2449aa74d0c3956a84de568 --- /dev/null +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -0,0 +1,42 @@ +import pytest +import torch +import transformers +from hf_tracer_utils import trace_model_and_compare_output + +BATCH_SIZE = 1 +SEQ_LENGTH = 16 + + +def test_t5(): + MODEL_LIST = [ + transformers.T5Model, + transformers.T5ForConditionalGeneration, + transformers.T5EncoderModel, + ] + + config = transformers.T5Config(d_model=128, num_layers=2) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + return kwargs + + def data_gen_for_encoder_only(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + kwargs = dict(input_ids=input_ids) + return kwargs + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + + if isinstance(model, transformers.T5EncoderModel): + data_gen_func = data_gen_for_encoder_only + else: + data_gen_func = data_gen + + trace_model_and_compare_output(model, data_gen_func) + + +if __name__ == '__main__': + test_t5() diff --git a/tests/test_fx/test_tracer/test_patched_module.py b/tests/test_fx/test_tracer/test_patched_module.py new file mode 100644 index 0000000000000000000000000000000000000000..94a93e16f3c76bc0d90f8c9b61707e681b1c3d58 --- /dev/null +++ b/tests/test_fx/test_tracer/test_patched_module.py @@ -0,0 +1,482 @@ +import torch +from colossalai.fx.tracer.meta_patch import patched_module + + +def _run(data, module, patch_fn): + try: + if isinstance(data, dict): + output = patch_fn(module, **data) + if isinstance(data, tuple) or isinstance(data, list): + output = patch_fn(module, *data) + else: + output = patch_fn(module, data) + return output + except Exception as e: + return e + + +def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape): + output = _run(data, module, patch_fn) + + if expect_exception: + assert isinstance(output, AssertionError) + else: + assert not isinstance(output, Exception) + if isinstance(output, tuple): + for item, shape in zip(output, output_shape): + assert item.is_meta + assert item.shape == shape + else: + assert output.is_meta + assert output.shape == output_shape + + +def test_linear(): + # test linear patch can produce the meta output with correct shape + data = torch.rand(2, 4, device='meta') + module = torch.nn.Linear(4, 2) + _assert_output_shape(data, module, patched_module.torch_nn_linear, False, torch.Size([2, 2])) + + # test if the linear patch can catch exception when dimension does not match + data = torch.rand(2, 2, device='meta') + _assert_output_shape(data, module, patched_module.torch_nn_linear, True, None) + + +def test_rnn(): + # test rnn patch can produce the meta output with correct shape + data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20)) + module = torch.nn.RNN(10, 20, 2) + output, hn = module(*data) + meta_data = (torch.randn(5, 3, 10).to('meta'), torch.randn(2, 3, 20).to('meta')) + _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, False, (output.shape, hn.shape)) + + # test if the rnn patch can catch exception when dimension does not match + data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20)) + module = torch.nn.RNN(10, 20, 2) + output, hn = module(*data) + meta_data = (torch.randn(5, 3, 1).to('meta'), torch.randn(2, 3, 20).to('meta')) + _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, True, None) + + +def test_embedding(): + data = torch.rand(2, 4, device='meta') + + # test layernorm + ln = torch.nn.LayerNorm(4) + _assert_output_shape(data, ln, patched_module.torch_nn_normalize, False, data.shape) + + # test group norm + gn = torch.nn.GroupNorm(4, num_channels=8) + _assert_output_shape(data, gn, patched_module.torch_nn_normalize, False, data.shape) + + # test batch norm 1d + bn1d = torch.nn.BatchNorm1d(4) + data = torch.rand(2, 4, device='meta') + _assert_output_shape(data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape) + + data = torch.rand(2, 4, device='meta') + _assert_output_shape(data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape) + + data = torch.rand(2, 3, 4, device='meta') + _assert_output_shape(data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape) + + data = torch.rand(1, 2, 3, 4, device='meta') + _assert_output_shape(data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=True, + output_shape=None) + + # test batch norm 2d + bn2d = torch.nn.BatchNorm2d(4) + + data = torch.rand(1, 2, 3, 4, device='meta') + _assert_output_shape(data=data, + module=bn2d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape) + + data = torch.rand(2, 3, 4, device='meta') + _assert_output_shape(data=data, + module=bn2d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=True, + output_shape=None) + + # # test batch size 3d + bn3d = torch.nn.BatchNorm3d(4) + + data = torch.rand(1, 1, 2, 3, 4, device='meta') + _assert_output_shape(data=data, + module=bn3d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape) + + data = torch.rand(1, 2, 3, 4, device='meta') + _assert_output_shape(data=data, + module=bn3d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=True, + output_shape=None) + + +def test_conv1d(): + # test conv 1d + data = torch.rand(2, 3, 4) + + conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2) + materialized_output = conv1d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=conv1d, + patch_fn=patched_module.torch_nn_conv1d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1) + materialized_output = conv1d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=conv1d, + patch_fn=patched_module.torch_nn_conv1d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv1d = torch.nn.Conv1d(in_channels=3, + out_channels=4, + kernel_size=2, + padding=1, + dilation=2, + padding_mode='reflect') + materialized_output = conv1d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=conv1d, + patch_fn=patched_module.torch_nn_conv1d, + expect_exception=False, + output_shape=materialized_output.shape) + + +def test_conv2d(): + # test conv 2d + data = torch.rand(2, 3, 4, 4) + conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2) + materialized_output = conv2d(data) + _assert_output_shape(data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1) + materialized_output = conv2d(data) + _assert_output_shape(data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2) + materialized_output = conv2d(data) + _assert_output_shape(data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv2d = torch.nn.Conv2d(in_channels=3, + out_channels=4, + kernel_size=2, + padding=1, + dilation=2, + padding_mode='reflect') + materialized_output = conv2d(data) + _assert_output_shape(data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape) + + +def test_conv3d(): + # test conv 3d + data = torch.rand(2, 3, 4, 4, 4) + conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2) + materialized_output = conv3d(data) + _assert_output_shape(data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1) + materialized_output = conv3d(data) + _assert_output_shape(data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2) + materialized_output = conv3d(data) + _assert_output_shape(data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv3d = torch.nn.Conv3d(in_channels=3, + out_channels=4, + kernel_size=2, + padding=1, + dilation=2, + padding_mode='reflect') + materialized_output = conv3d(data) + _assert_output_shape(data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape) + + +def test_conv_transpose1d(): + # test conv transpose1d + data = torch.rand(2, 3, 4) + + convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2) + materialized_output = convtrans1d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans1d, + patch_fn=patched_module.torch_nn_convtranspose1d, + expect_exception=False, + output_shape=materialized_output.shape) + + convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1) + materialized_output = convtrans1d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans1d, + patch_fn=patched_module.torch_nn_convtranspose1d, + expect_exception=False, + output_shape=materialized_output.shape) + + +def test_conv_transpose2d(): + # test conv transpose2d + data = torch.rand(2, 3, 4, 4) + + convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2) + materialized_output = convtrans2d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans2d, + patch_fn=patched_module.torch_nn_convtranspose2d, + expect_exception=False, + output_shape=materialized_output.shape) + + convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1) + materialized_output = convtrans2d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans2d, + patch_fn=patched_module.torch_nn_convtranspose2d, + expect_exception=False, + output_shape=materialized_output.shape) + + +def test_conv_transpose3d(): + # test conv transpose2d + data = torch.rand(2, 3, 4, 4, 4) + + convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2) + materialized_output = convtrans3d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans3d, + patch_fn=patched_module.torch_nn_convtranspose3d, + expect_exception=False, + output_shape=materialized_output.shape) + + convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1) + materialized_output = convtrans3d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans3d, + patch_fn=patched_module.torch_nn_convtranspose3d, + expect_exception=False, + output_shape=materialized_output.shape) + + +def test_pool1d(): + combinations = [[torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d], + [torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d]] + + for (layer_cls, patch_func) in combinations: + pooler = layer_cls(kernel_size=3) + + data = torch.rand(2, 3, 4) + materialized_output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape) + + data = torch.rand(2, 4) + materialized_output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape) + + data = torch.rand(2, 3, 4, 4) + _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + +def test_pool2d(): + combinations = [[torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d], + [torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d]] + + for (layer_cls, patch_func) in combinations: + pooler = layer_cls(kernel_size=3) + + # test max pool 3d + data = torch.rand(2, 3, 4, 4) + materialized_output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape) + + # test max pool 3d + data = torch.rand(2, 4, 4) + materialized_output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape) + + # test max pool 3d + data = torch.rand(2, 3, 4, 4, 4) + _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + +def test_pool3d(): + combinations = [[torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d], + [torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d]] + + for (layer_cls, patch_func) in combinations: + pooler = layer_cls(kernel_size=3) + + # test max pool 3d + data = torch.rand(2, 3, 4, 4, 4) + materialized_output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape) + + # test max pool 3d + data = torch.rand(2, 4, 4, 4) + materialized_output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape) + + # test max pool 3d + data = torch.rand(2, 3, 4) + _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + +# adapative pooling is different from other pooling, so test it individually +def test_adaptive_pooling_1d(): + pooler = torch.nn.AdaptiveAvgPool1d(output_size=3) + patch_func = patched_module.torch_nn_adapative_pooling_1d + + data = torch.rand(3, 4) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) + + data = torch.rand(2, 3, 4) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) + + data = torch.rand(2, 3, 4, 5) + _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + +def test_adaptive_pooling_2d(): + pooler = torch.nn.AdaptiveAvgPool2d(output_size=3) + patch_func = patched_module.torch_nn_adapative_pooling_2d + + data = torch.rand(3, 4) + _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + data = torch.rand(2, 3, 4) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) + + data = torch.rand(2, 3, 4, 5) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) + + +def test_adaptive_pooling_3d(): + pooler = torch.nn.AdaptiveAvgPool3d(output_size=3) + patch_func = patched_module.torch_nn_adapative_pooling_3d + + data = torch.rand(3, 4, 5) + _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + data = torch.rand(2, 3, 4, 5) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) + + data = torch.rand(2, 3, 4, 5, 6) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) diff --git a/tests/test_fx/test_tracer/test_patched_op.py b/tests/test_fx/test_tracer/test_patched_op.py new file mode 100644 index 0000000000000000000000000000000000000000..4406f02db24be8604e353ffd38a295b4072fd446 --- /dev/null +++ b/tests/test_fx/test_tracer/test_patched_op.py @@ -0,0 +1,82 @@ +import torch +from colossalai.fx.tracer.meta_patch import patched_function +from functools import partial + + +def _run(data, patch_fn): + try: + output = patch_fn(data) + return output + except Exception as e: + return e + + +def _assert_output_shape(data, patch_fn, expect_exception, output_shape): + output = _run(data, patch_fn) + + if expect_exception: + assert isinstance(output, AssertionError) + else: + assert not isinstance(output, Exception) + assert output.is_meta + assert output.shape == output_shape + + +def test_repeat_interleave(): + patch_fn = patched_function.torch_repeat_interleave + + # examples from https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html + data = torch.tensor([1, 2, 3]) + materialized_output = torch.repeat_interleave(data, repeats=2) + repeat_interleave = partial(patch_fn, repeats=2) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + patch_fn=repeat_interleave, + expect_exception=False, + output_shape=materialized_output.shape) + + data = torch.tensor([[1, 2], [3, 4]]) + materialized_output = torch.repeat_interleave(data, repeats=3, dim=1) + repeat_interleave = partial(patch_fn, repeats=3, dim=1) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + patch_fn=repeat_interleave, + expect_exception=False, + output_shape=materialized_output.shape) + + data = torch.tensor([[1, 2], [3, 4]]) + materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=-1) + repeat_interleave = partial(patch_fn, repeats=torch.tensor([1, 2]), dim=-1) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + patch_fn=repeat_interleave, + expect_exception=False, + output_shape=materialized_output.shape) + + data = torch.tensor([[1, 2], [3, 4]]) + materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=0) + repeat_interleave = partial(patch_fn, repeats=[1, 2], dim=0) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + patch_fn=repeat_interleave, + expect_exception=True, + output_shape=materialized_output.shape) + + +def test_torch_max(): + data = torch.rand(4, 3) + out = torch.max(data) + patched_out = patched_function.torch_max(data) + assert out.shape == patched_out.shape + + data = torch.rand(4, 3, 2) + out, idx = torch.max(data, dim=1) + patched_out, patched_idx = patched_function.torch_max(data, dim=1) + assert out.shape == patched_out.shape + assert idx.shape == patched_idx.shape + + data = torch.rand(4, 3, 2) + out, idx = torch.max(data, dim=1, keepdim=True) + patched_out, patched_idx = patched_function.torch_max(data, dim=1, keepdim=True) + assert out.shape == patched_out.shape + assert idx.shape == patched_idx.shape diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..28ec3d82556ced033757aa49b82fa931beeab0c2 --- /dev/null +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -0,0 +1,73 @@ +import pytest +import timm.models as tm +import torch + +from colossalai.fx import symbolic_trace + + +def trace_and_compare(model_cls, data, meta_args=None): + # trace + model = model_cls() + + # convert to eval for inference + # it is important to set it to eval mode before tracing + # without this statement, the torch.nn.functional.batch_norm will always be in training mode + model.eval() + + gm = symbolic_trace(model, meta_args=meta_args) + + # run forward + with torch.no_grad(): + fx_out = gm(data) + non_fx_out = model(data) + + # compare output + if isinstance(fx_out, tuple): + # some models produce tuple as output + for v1, v2 in zip(fx_out, non_fx_out): + assert torch.allclose(v1, v2), f'{model.__class__.__name__} has inconsistent outputs, {v1} vs {v2}' + else: + assert torch.allclose( + fx_out, non_fx_out, + atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + + +def test_timm_models_without_control_flow(): + torch.backends.cudnn.deterministic = True + + MODEL_LIST = [ + tm.resnest.resnest50d, + tm.beit.beit_base_patch16_224, + tm.cait.cait_s24_224, + tm.convmixer.convmixer_768_32, + tm.efficientnet.efficientnetv2_m, + tm.resmlp_12_224, + tm.vision_transformer.vit_base_patch16_224, + tm.deit_base_distilled_patch16_224, + ] + + data = torch.rand(2, 3, 224, 224) + + for model_cls in MODEL_LIST: + trace_and_compare(model_cls, data) + + +def test_timm_models_with_control_flow(): + torch.backends.cudnn.deterministic = True + + MODEL_LIST_WITH_CONTROL_FLOW = [ + tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100, + tm.swin_transformer.swin_base_patch4_window7_224 + ] + + data = torch.rand(2, 3, 224, 224) + + meta_args = {'x': data.to('meta')} + + for model_cls in MODEL_LIST_WITH_CONTROL_FLOW: + trace_and_compare(model_cls, data, meta_args) + + +if __name__ == '__main__': + test_timm_models_with_control_flow() + test_timm_models_without_control_flow() diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_general.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_general.py new file mode 100644 index 0000000000000000000000000000000000000000..b2fa8c6c0bbbf2eb375c423ae3fb8aae06d8146b --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_general.py @@ -0,0 +1,145 @@ +import torch +from torchaudio_utils import trace_and_compare +from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN +from torchaudio.models.wavernn import MelResNet, UpsampleNetwork +import pytest + + +def test_wave2letter_waveform(): + batch_size = 2 + num_features = 1 + num_classes = 40 + input_length = 320 + + model = Wav2Letter(num_classes=num_classes, num_features=num_features) + + def data_gen(): + x = torch.rand(batch_size, num_features, input_length) + return dict(x=x) + + trace_and_compare(model, data_gen, need_meta=False, need_concrete=False) + + +def test_wave2letter_mfcc(): + batch_size = 2 + num_features = 13 + num_classes = 40 + input_length = 2 + + model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features) + + def data_gen(): + x = torch.rand(batch_size, num_features, input_length) + return dict(x=x) + + trace_and_compare(model, data_gen, need_meta=False, need_concrete=False) + + +def test_melresnet_waveform(): + n_batch = 2 + n_time = 200 + n_freq = 100 + n_output = 128 + n_res_block = 10 + n_hidden = 128 + kernel_size = 5 + + model = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size) + + def data_gen(): + x = torch.rand(n_batch, n_freq, n_time) + return dict(specgram=x) + + trace_and_compare(model, data_gen, need_meta=False, need_concrete=False) + + +def test_upsample_network_waveform(): + upsample_scales = [5, 5, 8] + n_batch = 2 + n_time = 200 + n_freq = 100 + n_output = 64 + n_res_block = 10 + n_hidden = 32 + kernel_size = 5 + + total_scale = 1 + for upsample_scale in upsample_scales: + total_scale *= upsample_scale + + model = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size) + + def data_gen(): + x = torch.rand(n_batch, n_freq, n_time) + return dict(specgram=x) + + trace_and_compare(model, data_gen, need_meta=False, need_concrete=False) + + +def test_wavernn_waveform(): + upsample_scales = [2, 2, 5] + n_rnn = 16 + n_fc = 16 + n_classes = 10 + hop_length = 20 + n_batch = 2 + n_time = 20 + n_freq = 10 + n_output = 16 + n_res_block = 3 + n_hidden = 16 + kernel_size = 5 + + model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden, + n_output) + + def data_gen(): + x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1)) + mels = torch.rand(n_batch, 1, n_freq, n_time) + return dict(waveform=x, specgram=mels) + + trace_and_compare(model, data_gen, need_meta=True, need_concrete=False) + + +def test_convtasnet_config(): + batch_size = 32 + num_frames = 800 + + model = ConvTasNet() + + def data_gen(): + tensor = torch.rand(batch_size, 1, num_frames) + return dict(input=tensor) + + trace_and_compare(model, data_gen, need_meta=True, need_concrete=False) + + +def test_deepspeech(): + n_batch = 2 + n_feature = 1 + n_channel = 1 + n_class = 40 + n_time = 32 + + model = DeepSpeech(n_feature=n_feature, n_class=n_class) + + def data_gen(): + x = torch.rand(n_batch, n_channel, n_time, n_feature) + return dict(x=x) + + trace_and_compare(model, data_gen, need_meta=False, need_concrete=False) + + +if __name__ == '__main__': + TEST_LIST = [ + test_wave2letter_waveform, + test_wave2letter_mfcc, + test_melresnet_waveform, + test_upsample_network_waveform, + test_wavernn_waveform, + test_convtasnet_config, + test_deepspeech, + ] + + for test_fn in TEST_LIST: + test_fn() diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py new file mode 100644 index 0000000000000000000000000000000000000000..2073c46897f4784749f5274c87f9ff5c0a88ddd5 --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py @@ -0,0 +1,57 @@ +import torch +from torchaudio.models import Tacotron2 +from torchaudio_utils import trace_and_compare +import pytest + + +def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5): + return Tacotron2( + mask_padding=False, + n_mels=n_mels, + n_symbol=20, + n_frames_per_step=1, + symbol_embedding_dim=32, + encoder_embedding_dim=32, + encoder_n_convolution=3, + encoder_kernel_size=5, + decoder_rnn_dim=32, + decoder_max_step=decoder_max_step, + decoder_dropout=0.1, + decoder_early_stopping=True, + attention_rnn_dim=32, + attention_hidden_dim=32, + attention_location_n_filter=32, + attention_location_kernel_size=31, + attention_dropout=0.1, + prenet_dim=32, + postnet_n_convolution=5, + postnet_kernel_size=5, + postnet_embedding_dim=512, + gate_threshold=gate_threshold, + ) + + +@pytest.mark.skip("Tracing failed") +def test_tacotron_model(): + n_mels = 80 + n_batch = 3 + max_mel_specgram_length = 300 + max_text_length = 100 + + model = _get_tacotron2_model(n_mels) + + def data_gen(): + text = torch.randint(0, 148, (n_batch, max_text_length)) + text_lengths = max_text_length * torch.ones((n_batch,)) + mel_specgram = torch.rand(n_batch, n_mels, max_mel_specgram_length) + mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,)) + return dict(tokens=text, + token_lengths=text_lengths, + mel_specgram=mel_specgram, + mel_specgram_lengths=mel_specgram_lengths) + + trace_and_compare(model, data_gen, need_meta=True, need_concrete=False) + + +if __name__ == "__main__": + test_tacotron_model() diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe24a8cd91fb845b9cbc6c0e8d2878d3c2a0d08 --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py @@ -0,0 +1,67 @@ +import torch +from torchaudio_utils import trace_and_compare +from torchaudio.models import Emformer, Conformer +import pytest + + +def test_conformer(): + input_dim = 80 + batch_size = 10 + num_frames = 400 + num_heads = 4 + ffn_dim = 128 + num_layers = 4 + depthwise_conv_kernel_size = 31 + + model = Conformer( + input_dim=input_dim, + num_heads=num_heads, + ffn_dim=ffn_dim, + num_layers=num_layers, + depthwise_conv_kernel_size=depthwise_conv_kernel_size, + ) + + def data_gen(): + lengths = torch.randint(1, num_frames, (batch_size,)) + input = torch.rand(batch_size, int(lengths.max()), input_dim) + return dict(input=input, lengths=lengths) + + def kwargs_transform(data): + new_data = {} + + for k, v in data.items(): + new_data[f'{k}_1'] = v + return new_data + + trace_and_compare(model, data_gen, need_meta=False, need_concrete=True, kwargs_transform=kwargs_transform) + + +@pytest.mark.skip("Tracing failed") +def test_emformer(): + input_dim = 128 + batch_size = 10 + num_heads = 8 + ffn_dim = 256 + num_layers = 3 + segment_length = 4 + num_frames = 400 + right_context_length = 1 + + model = Emformer(input_dim, num_heads, ffn_dim, num_layers, segment_length, right_context_length) + + def data_gen(): + lengths = torch.randint(1, num_frames, (batch_size,)) + input = torch.rand(batch_size, num_frames, input_dim) + return dict(input=input, lengths=lengths) + + trace_and_compare(model, data_gen, need_meta=True, need_concrete=False) + + +@pytest.mark.skip +def test_torchaudio_transformers(): + test_conformer() + test_emformer() + + +if __name__ == "__main__": + test_torchaudio_transformers() diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py new file mode 100644 index 0000000000000000000000000000000000000000..e8729b83fba035d5039339d67c948f02daec8f03 --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py @@ -0,0 +1,50 @@ +import torch +from torchaudio.models.wav2vec2 import ( + hubert_base, + hubert_large, + hubert_xlarge, + wav2vec2_base, + wav2vec2_large, + wav2vec2_large_lv60k, +) +from torchaudio_utils import trace_and_compare +import pytest + +MODEL_LIST = [ + hubert_base, + hubert_large, + hubert_xlarge, + wav2vec2_base, + wav2vec2_large, + wav2vec2_large_lv60k, +] + + +def _smoke_test(model, device): + model = model.to(device=device) + + batch_size, num_frames = 3, 1024 + + def data_gen(): + waveforms = torch.randn(batch_size, num_frames, device=device) + lengths = torch.randint( + low=0, + high=num_frames, + size=[ + batch_size, + ], + device=device, + ) + return dict(waveforms=waveforms, lengths=lengths) + + trace_and_compare(model, data_gen, need_meta=True, need_concrete=False) + + +@pytest.mark.skip("Tracing failed") +def test_wav2vec(): + for model_fn in MODEL_LIST: + _smoke_test(model_fn(), 'cpu') + + +if __name__ == "__main__": + test_wav2vec() diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..702c5f8f6a247509d50caa1def12b27591f9a66a --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py @@ -0,0 +1,29 @@ +import torch + +from colossalai.fx import symbolic_trace + + +def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False): + data = data_gen() + concrete_args = data if need_concrete else {} + meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {} + + model.eval() + + gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args) + + with torch.no_grad(): + non_fx_out = model(**data) + + if kwargs_transform: + data = kwargs_transform(data) + + fx_out = gm(**data) + if isinstance(fx_out, tuple): + for non_fx, fx in zip(non_fx_out, fx_out): + assert torch.allclose( + non_fx, fx, atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + else: + assert torch.allclose( + fx_out, non_fx_out, + atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dbe8a62e7c591ffc784a1fa3a9d3537413e6ff79 --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -0,0 +1,86 @@ +import pytest +import torch + +from colossalai.fx import symbolic_trace + +try: + from torchrec.models import deepfm + from torchrec.modules.embedding_configs import EmbeddingBagConfig + from torchrec.modules.embedding_modules import EmbeddingBagCollection + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + NOT_TORCHREC = False +except ImportError: + NOT_TORCHREC = True + +BATCH = 2 +SHAPE = 10 + + +@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed') +def test_torchrec_deepfm_models(): + MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch] + + # Data Preparation + # EmbeddingBagCollection + eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"]) + eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"]) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + keys = ["f1", "f2"] + + # KeyedTensor + KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE))) + + # KeyedJaggedTensor + KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys, + values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), + offsets=torch.tensor([0, 2, 4, 6, 8])) + + # Dense Features + features = torch.rand((BATCH, SHAPE)) + + for model_cls in MODEL_LIST: + # Initializing model + if model_cls == deepfm.DenseArch: + model = model_cls(SHAPE, SHAPE, SHAPE) + elif model_cls == deepfm.FMInteractionArch: + model = model_cls(SHAPE * 3, keys, SHAPE) + elif model_cls == deepfm.OverArch: + model = model_cls(SHAPE) + elif model_cls == deepfm.SimpleDeepFMNN: + model = model_cls(SHAPE, ebc, SHAPE, SHAPE) + elif model_cls == deepfm.SparseArch: + model = model_cls(ebc) + + # Setup GraphModule + gm = symbolic_trace(model) + + model.eval() + gm.eval() + + # Aligned Test + with torch.no_grad(): + if model_cls == deepfm.DenseArch or model_cls == deepfm.OverArch: + fx_out = gm(features) + non_fx_out = model(features) + elif model_cls == deepfm.FMInteractionArch: + fx_out = gm(features, KT) + non_fx_out = model(features, KT) + elif model_cls == deepfm.SimpleDeepFMNN: + fx_out = gm(features, KJT) + non_fx_out = model(features, KJT) + elif model_cls == deepfm.SparseArch: + fx_out = gm(KJT) + non_fx_out = model(KJT) + + if torch.is_tensor(fx_out): + assert torch.allclose( + fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + else: + assert torch.allclose( + fx_out.values(), + non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + + +if __name__ == "__main__": + test_torchrec_deepfm_models() diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2f9fd8fe5982fd04747bc3049da3d95bf428fe8f --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -0,0 +1,112 @@ +import torch + +from colossalai.fx import symbolic_trace + +try: + from torchrec.models import dlrm + from torchrec.modules.embedding_configs import EmbeddingBagConfig + from torchrec.modules.embedding_modules import EmbeddingBagCollection + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + NOT_TORCHREC = False +except ImportError: + NOT_TORCHREC = True + +import pytest + +BATCH = 2 +SHAPE = 10 + + +@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed') +def test_torchrec_dlrm_models(): + MODEL_LIST = [ + dlrm.DLRM, + dlrm.DenseArch, + dlrm.InteractionArch, + dlrm.InteractionV2Arch, + dlrm.OverArch, + dlrm.SparseArch, + # dlrm.DLRMV2 + ] + + # Data Preparation + # EmbeddingBagCollection + eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"]) + eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"]) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + keys = ["f1", "f2"] + + # KeyedTensor + KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE))) + + # KeyedJaggedTensor + KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys, + values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), + offsets=torch.tensor([0, 2, 4, 6, 8])) + + # Dense Features + dense_features = torch.rand((BATCH, SHAPE)) + + # Sparse Features + sparse_features = torch.rand((BATCH, len(keys), SHAPE)) + + for model_cls in MODEL_LIST: + # Initializing model + if model_cls == dlrm.DLRM: + model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1]) + elif model_cls == dlrm.DenseArch: + model = model_cls(SHAPE, [SHAPE, SHAPE]) + elif model_cls == dlrm.InteractionArch: + model = model_cls(len(keys)) + elif model_cls == dlrm.InteractionV2Arch: + I1 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE]) + I2 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE]) + model = model_cls(len(keys), I1, I2) + elif model_cls == dlrm.OverArch: + model = model_cls(SHAPE, [5, 1]) + elif model_cls == dlrm.SparseArch: + model = model_cls(ebc) + elif model_cls == dlrm.DLRMV2: + # Currently DLRMV2 cannot be traced + model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1], [4 * SHAPE, 4 * SHAPE], [4 * SHAPE, 4 * SHAPE]) + + # Setup GraphModule + if model_cls == dlrm.InteractionV2Arch: + concrete_args = {"dense_features": dense_features, "sparse_features": sparse_features} + gm = symbolic_trace(model, concrete_args=concrete_args) + else: + gm = symbolic_trace(model) + + model.eval() + gm.eval() + + # Aligned Test + with torch.no_grad(): + if model_cls == dlrm.DLRM or model_cls == dlrm.DLRMV2: + fx_out = gm(dense_features, KJT) + non_fx_out = model(dense_features, KJT) + elif model_cls == dlrm.DenseArch: + fx_out = gm(dense_features) + non_fx_out = model(dense_features) + elif model_cls == dlrm.InteractionArch or model_cls == dlrm.InteractionV2Arch: + fx_out = gm(dense_features, sparse_features) + non_fx_out = model(dense_features, sparse_features) + elif model_cls == dlrm.OverArch: + fx_out = gm(dense_features) + non_fx_out = model(dense_features) + elif model_cls == dlrm.SparseArch: + fx_out = gm(KJT) + non_fx_out = model(KJT) + + if torch.is_tensor(fx_out): + assert torch.allclose( + fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + else: + assert torch.allclose( + fx_out.values(), + non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + + +if __name__ == "__main__": + test_torchrec_dlrm_models() diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2a6c6ae1674badef21167bcf2914d91c96dae300 --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py @@ -0,0 +1,45 @@ +import torch +import torchvision +import torchvision.models as tm +from packaging import version + +from colossalai.fx import symbolic_trace + + +def test_torchvision_models(): + MODEL_LIST = [ + tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, + tm.regnet_x_16gf, tm.mnasnet0_5, tm.efficientnet_b0 + ] + + RANDOMIZED_MODELS = [tm.efficientnet_b0] + + if version.parse(torchvision.__version__) >= version.parse('0.12.0'): + MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small]) + RANDOMIZED_MODELS.append(tm.convnext_small) + + torch.backends.cudnn.deterministic = True + + data = torch.rand(2, 3, 224, 224) + + for model_cls in MODEL_LIST: + if model_cls in RANDOMIZED_MODELS: + # remove the impact of randomicity + model = model_cls(stochastic_depth_prob=0) + else: + model = model_cls() + + gm = symbolic_trace(model) + + model.eval() + gm.eval() + + with torch.no_grad(): + fx_out = gm(data) + non_fx_out = model(data) + assert torch.allclose( + fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + + +if __name__ == '__main__': + test_torchvision_models() diff --git a/tests/test_gemini/test_gemini_manager.py b/tests/test_gemini/test_gemini_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..1f85eb751f439d080a073f81846422c193d21d2b --- /dev/null +++ b/tests/test_gemini/test_gemini_manager.py @@ -0,0 +1,73 @@ +import pytest +import torch + +from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor + + +@pytest.mark.dist +def test_gemini_manager(): + # reset the manager, in case that there exists memory information left + manager = StatefulTensor.GST_MGR + manager.reset() + + # occupation 8 + st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda')) + # occupation 60 + st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu')) + + # occupation 28 + t1 = torch.empty(7, device='cuda') + # occupation 12 + t2 = torch.empty(3, device='cpu') + st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD) + st4 = StatefulTensor(None, TensorState.FREE) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 60 + assert manager.total_mem['cuda'] == 36 + assert manager.state_mem['cpu'][TensorState.HOLD] == 60 + assert manager.state_mem['cuda'][TensorState.HOLD] == 8 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28 + + st4.payload_reset(t2) + st3.payload_reset(t2) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 84 + assert manager.total_mem['cuda'] == 8 + assert manager.state_mem['cpu'][TensorState.HOLD] == 72 + assert manager.state_mem['cuda'][TensorState.HOLD] == 8 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0 + + st1.move_to(torch.device('cpu')) + st2.move_to(torch.device('cpu')) + st3.move_to(torch.device('cuda', 0)) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 80 + assert manager.total_mem['cuda'] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD] == 80 + assert manager.state_mem['cuda'][TensorState.HOLD] == 0 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 + + st1.trans_state(TensorState.COMPUTE) + st2.trans_state(TensorState.COMPUTE) + st2.trans_state(TensorState.HOLD_AFTER_BWD) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 80 + assert manager.total_mem['cuda'] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD] == 12 + assert manager.state_mem['cuda'][TensorState.HOLD] == 0 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0 + assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8 + assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0 + + +if __name__ == '__main__': + test_gemini_manager() diff --git a/tests/test_gemini/test_param_op.py b/tests/test_gemini/test_param_op.py new file mode 100644 index 0000000000000000000000000000000000000000..daf386d6d6af90a568beaca8d0937dc032892e0e --- /dev/null +++ b/tests/test_gemini/test_param_op.py @@ -0,0 +1,80 @@ +import copy + +import torch + +from colossalai.gemini.paramhooks import BaseParamHookMgr +from tests.components_to_test.registry import non_distributed_component_funcs + + +def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: + if loose: + return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3) + return torch.allclose(tensor_a, tensor_b) + + +def run_model(model, inputs, label, criterion, use_param_hook=False): + if use_param_hook: + + class HooKWrapper: + + def __init__(self) -> None: + self.hook_triggered_times = 0 + + def wrapper_func(self): + + def hook(param, grad) -> torch.Tensor or None: + self.hook_triggered_times += 1 + return grad + + return hook + + hookwrapper = HooKWrapper() + param_list = [p for p in model.parameters()] + hook_mgr = BaseParamHookMgr(param_list) + hook_mgr.register_backward_hooks(hookwrapper.wrapper_func()) + + model.zero_grad(set_to_none=True) + + with torch.cuda.amp.autocast(): + if criterion: + y = model(inputs) + loss = criterion(y, label) + else: + loss = model(inputs, label) + loss = loss.float() + loss.backward() + + if use_param_hook: + hook_mgr.remove_hooks() + return hookwrapper.hook_triggered_times + + +def test_base_param_hook(): + test_models = ['repeated_computed_layers', 'resnet18', 'hanging_param_model', 'inline_op_model'] + # test_models = ['bert'] + + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, _, criterion = get_components_func() + + torch.manual_seed(0) + model = model_builder(checkpoint=True).cuda() + model.train() + + for i, (inputs, label) in enumerate(train_dataloader): + if i > 0: + break + model_copy = copy.deepcopy(model) + + run_model(model, inputs.cuda(), label.cuda(), criterion, False) + ret2 = run_model(model_copy, inputs.cuda(), label.cuda(), criterion, True) + + # Make sure param hook has only be fired once in case of parameter sharing + assert ret2 == len(list(model.parameters())) + + for p, p_copy in zip(model.parameters(), model_copy.parameters()): + assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}" + + +if __name__ == '__main__': + test_base_param_hook() diff --git a/tests/test_gemini/test_runtime_mem_tracer.py b/tests/test_gemini/test_runtime_mem_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..294868458c47c8cd1c81cec315412f219e7deabc --- /dev/null +++ b/tests/test_gemini/test_runtime_mem_tracer.py @@ -0,0 +1,52 @@ +from copy import deepcopy + +import numpy as np +import torch + +from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs + + +def test_runtime_mem_tracer(): + test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model', 'albert'] + + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, _, criterion = get_components_func() + + with ColoInitContext(device='cpu'): + model = model_builder(checkpoint=False) + + model_bk = deepcopy(model) + runtime_mem_tracer = RuntimeMemTracer(model) + + for i, (data, label) in enumerate(train_dataloader): + if i > 1: + break + data = data.cuda() + label = label.cuda() + + run_fwd_bwd(runtime_mem_tracer, data, label, criterion, optimizer=runtime_mem_tracer) + + for p1, p2 in zip(model_bk.parameters(), model.parameters()): + torch.allclose(p1.to(torch.half), p2) + + non_model_data_list = runtime_mem_tracer._memstats.non_model_data_list('cuda') + cuda_non_model_data_list = np.array(non_model_data_list) / 1024**2 + print("cuda_non_model_data_list", len(cuda_non_model_data_list)) + print(non_model_data_list) + + cnt1 = 0 + for p in runtime_mem_tracer.parameters_in_runtime_order(): + cnt1 += 1 + cnt2 = 0 + for p in model.parameters(): + cnt2 += 1 + assert cnt2 == cnt1, f'visited param number {cnt1} vs real param number {cnt2}' + del model + + +if __name__ == '__main__': + test_runtime_mem_tracer() diff --git a/tests/test_gemini/update/test_chunk_mgrv2.py b/tests/test_gemini/update/test_chunk_mgrv2.py new file mode 100644 index 0000000000000000000000000000000000000000..7d192fc631a6705d0d575bfaaa5b6ca80d91e2aa --- /dev/null +++ b/tests/test_gemini/update/test_chunk_mgrv2.py @@ -0,0 +1,72 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini.chunk import ChunkManager +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.test_tensor.common_utils import debug_print + +CUDA_MEM_0 = {False: 512, True: 1024} +CUDA_MEM_1 = {False: 0, True: 1024} +CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}} + + +@parameterize('keep_gathered', [True, False]) +@parameterize('pin_memory', [True, False]) +def exam_chunk_memory(keep_gathered, pin_memory): + pg = ProcessGroup() + + debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory)) + + params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)] + config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)} + + chunk_manager = ChunkManager(config) + assert chunk_manager.total_mem['cpu'] == 0 + assert chunk_manager.total_mem['cuda'] == 0 + + for p in params: + chunk_manager.register_tensor(p, 'param', 2, pin_memory=pin_memory) + chunk_manager.close_all_groups() + assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] + + chunks = chunk_manager.get_chunks(params) + + for chunk in chunks: + chunk_manager.access_chunk(chunk) + assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[True] + + for chunk in chunks: + chunk_manager.release_chunk(chunk) + + assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] + + for chunk in chunks: + chunk_manager.move_chunk(chunk, torch.device('cpu')) + assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][True] + assert chunk_manager.total_mem['cuda'] == CUDA_MEM_1[keep_gathered] + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_chunk_memory() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_chunk_manager(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_chunk_manager(2) diff --git a/tests/test_gemini/update/test_chunkv2.py b/tests/test_gemini/update/test_chunkv2.py new file mode 100644 index 0000000000000000000000000000000000000000..48cae94e1be76908e2c0ab0867461053338417f3 --- /dev/null +++ b/tests/test_gemini/update/test_chunkv2.py @@ -0,0 +1,124 @@ +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini import TensorState +from colossalai.gemini.chunk import Chunk +from colossalai.tensor import ColoParameter +from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device + + +def dist_sum(x): + temp = torch.tensor([x], device=get_current_device()) + dist.all_reduce(temp) + return temp.item() + + +def add_param(param_list, param_cp_list, *args, **kwargs): + param = ColoParameter(torch.randn(*args, **kwargs)) + param_list.append(param) + param_cp_list.append(param.clone()) + + +def check_euqal(param, param_cp): + if param.device != param_cp.device: + temp = param.data.to(param_cp.device) + else: + temp = param.data + return torch.equal(temp, param_cp.data) + + +@parameterize('init_device', [None, torch.device('cpu')]) +@parameterize('keep_gathered', [True, False]) +@parameterize('pin_memory', [True, False]) +def exam_chunk_basic(init_device, keep_gathered, pin_memory): + world_size = torch.distributed.get_world_size() + pg = ColoProcessGroup() + my_chunk = Chunk(chunk_size=1024, + process_group=pg, + dtype=torch.float32, + init_device=init_device, + cpu_shard_init=True, + keep_gathered=keep_gathered, + pin_memory=pin_memory) + + param_list = [] + param_cp_list = [] + + add_param(param_list, param_cp_list, 8, 8, 8, device='cuda') + add_param(param_list, param_cp_list, 4, 4) + add_param(param_list, param_cp_list, 4, 8, 2, device='cuda') + add_param(param_list, param_cp_list, 1, 1, 5) + + for param in param_list: + my_chunk.append_tensor(param) + assert my_chunk.utilized_size == 597 + for param, param_cp in zip(param_list, param_cp_list): + check_euqal(param, param_cp) + my_chunk.close_chunk() + + if keep_gathered is False: + assert my_chunk.cpu_shard.size(0) == 1024 // world_size + assert my_chunk.device_type == 'cpu' + assert my_chunk.can_move + my_chunk.shard_move(get_current_device()) + else: + assert my_chunk.cuda_global_chunk.size(0) == 1024 + assert my_chunk.device_type == 'cuda' + assert not my_chunk.can_move + + assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size + flag = my_chunk.has_inf_or_nan + assert not flag, "has_inf_or_nan is {}".format(flag) + + my_chunk.access_chunk() + assert my_chunk.device_type == 'cuda' + for param, param_cp in zip(param_list, param_cp_list): + check_euqal(param, param_cp) + + assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4 + my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE) + assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 3 + assert my_chunk.tensor_state_cnter[TensorState.COMPUTE] == 1 + assert not my_chunk.can_release + + for param in param_list: + my_chunk.tensor_trans_state(param, TensorState.COMPUTE) + my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE) + + assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4 + assert my_chunk.can_reduce + my_chunk.reduce() + assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4 + + if keep_gathered is False: + assert my_chunk.cuda_shard.size(0) == 1024 // world_size + assert my_chunk.device_type == 'cuda' + assert my_chunk.can_move + else: + assert my_chunk.cuda_global_chunk.size(0) == 1024 + assert my_chunk.device_type == 'cuda' + assert not my_chunk.can_move + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_chunk_basic() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2, 4]) +@rerun_if_address_is_in_use() +def test_chunk_function(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_chunk_function(4) diff --git a/tests/test_gemini/update/test_convert_torch_module.py b/tests/test_gemini/update/test_convert_torch_module.py new file mode 100644 index 0000000000000000000000000000000000000000..c0fd94b4044c4702a2ba0798b37489806b1ca485 --- /dev/null +++ b/tests/test_gemini/update/test_convert_torch_module.py @@ -0,0 +1,48 @@ +from functools import partial + +import pytest +import torch.multiprocessing as mp + +import colossalai +from colossalai.nn.parallel.utils import convert_to_torch_module +from colossalai.tensor import ColoTensor +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs + + +@parameterize('model_name', ['resnet18', 'bert']) +def run_convert_torch_module(model_name: str): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, _, _, _, _ = get_components_func() + + with ColoInitContext(device='cpu'): + model = model_builder(checkpoint=False) + + from colossalai.nn.parallel import GeminiDDP + model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True) + + pytorch_model = convert_to_torch_module(model) + + for n, p in pytorch_model.named_parameters(): + assert not isinstance(p, ColoTensor) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_convert_torch_module() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_convert_torch_module(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_convert_torch_module(2) diff --git a/tests/test_gemini/update/test_fwd_bwd.py b/tests/test_gemini/update/test_fwd_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..af98878e9e70663e0a51afefde7f8f1afca6e1cd --- /dev/null +++ b/tests/test_gemini/update/test_fwd_bwd.py @@ -0,0 +1,113 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer +from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed + + +def check_grad(model: ZeroDDP, torch_model: torch.nn.Module): + chunk_manager = model.chunk_manager + param_list = [p for p in model.parameters()] + chunk_list = chunk_manager.get_chunks(param_list) + for chunk in chunk_list: + chunk_manager.access_chunk(chunk) + + for (p0, p1) in zip(model.parameters(), torch_model.parameters()): + assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5) + + +@parameterize('init_device', [get_current_device()]) +@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) +@parameterize('keep_gather', [False, True]) +@parameterize('model_name', ['gpt2', 'bert', 'albert']) +@parameterize('use_grad_checkpoint', [False, True]) +def exam_gpt_fwd_bwd(placement_policy, + keep_gather, + model_name: str, + use_grad_checkpoint: bool = False, + init_device=get_current_device()): + + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + set_seed(42) + with ColoInitContext(device=init_device): + model = model_builder(use_grad_checkpoint) + + set_seed(42) + torch_model = model_builder(use_grad_checkpoint).cuda() + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gather + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1) + + pg = ProcessGroup() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) + + set_seed(pg.dp_local_rank()) + for i, (input_ids, label) in enumerate(train_dataloader): + # you can only test a single fwd + bwd. + # after bwd param is grad for Gemini, due to the chunk reuse optimization. + if i > 0: + break + input_ids, label = input_ids.cuda(), label.cuda() + + torch_optim.zero_grad() + zero_optim.zero_grad() + + # set random seed is same as torch_model.eval() + set_seed(42) + torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) + set_seed(42) + loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + + assert torch.equal(torch_loss, loss) + + check_grad(model, torch_model) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_gpt_fwd_bwd() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_gpt(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_gpt(4) diff --git a/tests/test_gemini/update/test_gemini_use_rmt.py b/tests/test_gemini/update/test_gemini_use_rmt.py new file mode 100644 index 0000000000000000000000000000000000000000..7fce84a5099ae589aa04c281e2b68b627a5b5edc --- /dev/null +++ b/tests/test_gemini/update/test_gemini_use_rmt.py @@ -0,0 +1,108 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer +from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer +from colossalai.nn.parallel import GeminiDDP, ZeroDDP +from colossalai.tensor import ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed + +# run gemini use the runtime memory tracer + + +@parameterize('placement_policy', ['auto']) +@parameterize('keep_gather', [False]) +@parameterize('model_name', ['repeated_computed_layers', 'bert', 'albert', 'gpt2']) +@parameterize('use_grad_checkpoint', [False, True]) +def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False): + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device='cpu'): + model = model_builder(use_grad_checkpoint) + + print(f'model_name {model_name}') + runtime_mem_tracer = RuntimeMemTracer(model) + for i, (input_ids, label) in enumerate(train_dataloader): + if i > 0: + break + input_ids, label = input_ids.cuda(), label.cuda() + + # mem tracing + if i == 0: + run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer) + memstats = runtime_mem_tracer.memstats() + runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list + print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data)) + print('runtime tracer: ', runtime_tracer_non_model_data) + print([memstats.param_used_step(p) for p in model.parameters()]) + + if model_name == 'repeated_computed_layers': + for idx, p in enumerate(model.parameters()): + step_list = memstats.param_used_step(p) + if idx < 4: + assert len(step_list) == 4 + + if model_name == 'repeated_computed_layers': + for idx, p in enumerate(model.parameters()): + step_list = memstats.param_used_step(p) + if idx < 4: + assert len(step_list) == 4 + + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gather + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + pg = ProcessGroup() + set_seed(pg.dp_local_rank()) + for i, (input_ids, label) in enumerate(train_dataloader): + # you can only test a single fwd + bwd. + # after bwd param is grad for Gemini, due to the chunk reuse optimization. + # print(f'iteration {i}') + if i > 4: + break + input_ids, label = input_ids.cuda(), label.cuda() + + set_seed(42) + loss = run_fwd_bwd(model, input_ids, label, criterion, model) + + gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda') + + # print('gemini non model data:', gemini_non_model_data) + + assert len(gemini_non_model_data) == len(runtime_tracer_non_model_data), \ + f'model_name {model_name} {len(gemini_non_model_data)} vs {len(runtime_tracer_non_model_data)}' + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_gemini_use_rmt() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_gemini_use_rmt(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_gemini_use_rmt(1) diff --git a/tests/test_gemini/update/test_grad_clip.py b/tests/test_gemini/update/test_grad_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..185521edb35750a4bf9fbd3a5f0738fadf58ff27 --- /dev/null +++ b/tests/test_gemini/update/test_grad_clip.py @@ -0,0 +1,117 @@ +from functools import partial +from time import time + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer +from colossalai.nn.parallel import ZeroDDP +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import debug_print, set_seed + + +def check_param(model: ZeroDDP, torch_model: torch.nn.Module): + zero_dict = model.state_dict(only_rank_0=False) + torch_dict = torch_model.state_dict() + + for key, value in torch_dict.items(): + # key is 'module.model.PARAMETER', so we truncate it + key = key[7:] + if key == 'model.lm_head.weight': + continue + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) + assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) +@parameterize('model_name', ['gpt2']) +def exam_grad_clipping(placement_policy, model_name: str): + set_seed(1912) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + torch_model = model_builder().cuda() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + + init_dev = get_current_device() + with ColoInitContext(device=init_dev): + model = model_builder() + + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + p.data.copy_(torch_p.data) + + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = False + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + chunk_manager = ChunkManager(config_dict, init_device=init_device) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0) + + model.train() + torch_model.train() + + set_seed(dist.get_rank() * 3 + 128) + for i, (data, label) in enumerate(train_dataloader): + if i > 2: + break + data = data.cuda() + label = label.cuda() + + zero_optim.zero_grad() + torch_optim.zero_grad() + + torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim) + loss = run_fwd_bwd(model, data, label, criterion, zero_optim) + assert_close(torch_loss, loss) + + import apex.amp as apex_amp + torch.nn.utils.clip_grad_norm_(apex_amp.master_params(torch_optim), 1.0) + torch_optim.step() + zero_optim.step() + + check_param(model, torch_model) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_grad_clipping() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_grad_clip(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_grad_clip(2) diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py new file mode 100644 index 0000000000000000000000000000000000000000..1f1d488a0b204adcf92b264689b198b36911b5a7 --- /dev/null +++ b/tests/test_gemini/update/test_optim.py @@ -0,0 +1,169 @@ +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer +from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ColoParameter, ColoTensor +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import debug_print, set_seed + +# this model is large enough to slice to chunks +TEST_MODELS = ['gpt2'] +# these models are too small, all parameters in these models are compacted into one chunk +EXAMPLE_MODELS = ['albert', 'hanging_param_model', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers'] + + +def check_param(model: ZeroDDP, torch_model: torch.nn.Module): + zero_dict = model.state_dict(only_rank_0=False) + torch_dict = torch_model.state_dict() + + for key, value in torch_dict.items(): + # key is 'module.model.PARAMETER', so we truncate it + key = key[7:] + if key == 'model.lm_head.weight': + continue + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) + assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) +@parameterize('model_name', TEST_MODELS) +def exam_model_step(placement_policy, model_name: str): + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + torch_model = model_builder().cuda() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + + init_dev = get_current_device() + with ColoInitContext(device=init_dev): + model = model_builder() + + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + p.data.copy_(torch_p.data) + + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = False + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + chunk_manager = ChunkManager(config_dict, init_device=init_device) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128) + + model.eval() + torch_model.eval() + + set_seed(dist.get_rank() * 3 + 128) + for i, (input_ids, label) in enumerate(train_dataloader): + if i > 2: + break + input_ids, label = input_ids.cuda(), label.cuda() + zero_optim.zero_grad() + torch_optim.zero_grad() + + torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) + loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + assert_close(torch_loss, loss) + + zero_optim.step() + torch_optim.step() + + check_param(model, torch_model) + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) +@parameterize('model_name', EXAMPLE_MODELS) +def exam_tiny_example(placement_policy, model_name: str): + set_seed(2008) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + torch_model = model_builder().cuda() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=2) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + + init_dev = get_current_device() + with ColoInitContext(device=init_dev): + model = model_builder() + + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + p.data.copy_(torch_p.data) + + chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) + + model.eval() + torch_model.eval() + + set_seed(dist.get_rank() * 3 + 128) + for i, (input_ids, label) in enumerate(train_dataloader): + if i > 2: + break + + input_ids = input_ids.cuda() + label = label.cuda() + + zero_optim.zero_grad() + torch_optim.zero_grad() + + torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) + loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + assert_close(torch_loss, loss) + + zero_optim.step() + torch_optim.step() + + check_param(model, torch_model) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_model_step() + exam_tiny_example() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_optim(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_optim(1) diff --git a/tests/test_gemini/update/test_search.py b/tests/test_gemini/update/test_search.py new file mode 100644 index 0000000000000000000000000000000000000000..e0b4e207f16f89275f3a685895ff8e38383cbb56 --- /dev/null +++ b/tests/test_gemini/update/test_search.py @@ -0,0 +1,65 @@ +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini.chunk import search_chunk_configuration +from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs + + +def init_1d_row_spec(model, pg: ProcessGroup): + tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + for n, p in model.named_parameters(): + if 'weight' in n and 'ln' not in n: + p.set_process_group(pg) + p.set_tensor_spec(*tensor_spec) + + +def exam_search_chunk_size(): + + world_size = torch.distributed.get_world_size() + pg_tp = ProcessGroup(tp_degree=world_size) + + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + # make sure torch_model and model has the same parameter values + with ColoInitContext(device=get_current_device()): + model = model_builder() + init_1d_row_spec(model, pg_tp) + config_dict, _ = search_chunk_configuration(model, + search_range_mb=1, + search_interval_byte=16, + min_chunk_size_mb=0, + filter_exlarge_params=True) + + for key in config_dict: + chunk_size = config_dict[key]['chunk_size'] + if world_size == 1: + assert chunk_size == 31616 + else: + assert chunk_size == 1024 + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_search_chunk_size() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_search(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_search(4) diff --git a/tests/test_gemini/update/test_zeroddp_state_dict.py b/tests/test_gemini/update/test_zeroddp_state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..7b0c6e37a7e84aa7e83d637c9f5cdf4fcea73605 --- /dev/null +++ b/tests/test_gemini/update/test_zeroddp_state_dict.py @@ -0,0 +1,110 @@ +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.parallel import ZeroDDP +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import debug_print, set_seed + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) +@parameterize('keep_gathered', [True, False]) +@parameterize('model_name', ['gpt2', 'bert']) +def exam_state_dict(placement_policy, keep_gathered, model_name: str): + set_seed(431) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + torch_model = model_builder() + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gathered + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + model.train() + + zero_dict = model.state_dict(only_rank_0=False) + torch_dict = torch_model.state_dict() + + for key, value in torch_dict.items(): + if key == 'model.lm_head.weight': + continue + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) +@parameterize('keep_gathered', [True, False]) +@parameterize('model_name', ['gpt2', 'bert']) +def exam_load_state_dict(placement_policy, keep_gathered, model_name: str): + set_seed(431) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + set_seed(451) + torch_model = model_builder() # get a different model + + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gathered + + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + chunk_manager = ChunkManager(config_dict, init_device=init_device) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + torch_dict = torch_model.state_dict() + model.load_state_dict(torch_dict, strict=False) + zero_dict = model.state_dict(only_rank_0=False) + + for key, value in torch_dict.items(): + if key == 'model.lm_head.weight': + continue + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_state_dict() + exam_load_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_zero_ddp(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_zero_ddp(1) diff --git a/tests/test_gemini/update/test_zerooptim_state_dict.py b/tests/test_gemini/update/test_zerooptim_state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..7f53415bf22ce603c640bd0cae811441177851f8 --- /dev/null +++ b/tests/test_gemini/update/test_zerooptim_state_dict.py @@ -0,0 +1,95 @@ +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer +from colossalai.nn.parallel import ZeroDDP +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import debug_print, set_seed + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) +@parameterize('keep_gathered', [True, False]) +def exam_zero_optim_state_dict(placement_policy, keep_gathered): + set_seed(431) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + set_seed(451) + torch_model = model_builder() # get a different model + + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gathered + + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + chunk_manager = ChunkManager(config_dict, init_device=init_device) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + optimizer = HybridAdam(model.parameters()) + optim = ZeroOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32 + + set_seed(dist.get_rank() * 3 + 128) + model.train() + for i, (input_ids, label) in enumerate(train_dataloader): + if i > 0: + break + optim.zero_grad() + logits = model(input_ids) + logits = logits.float() + loss = criterion(logits, input_ids) + optim.backward(loss) + optim.step() + + optim_state_dict = optim.state_dict() + optim.load_state_dict(optim_state_dict) + new_state = optim.state_dict()['state'] + org_state = optim_state_dict['state'] + + for k, v in org_state.items(): + w = new_state[k] + for n, m in v.items(): + if isinstance(m, torch.Tensor): + o = w[n] + if m.device != o.device: + o = o.to(m.device) + assert torch.equal(m, o) + else: + assert m == w[n] + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_zero_optim_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_zero_optim(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_zero_optim(1) diff --git a/tests/test_layers/test_1d/checks_1d/__init__.py b/tests/test_layers/test_1d/checks_1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_layers/test_1d/checks_1d/check_layer_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..668b8a334800753d9848347a8d88cba66b605de6 --- /dev/null +++ b/tests/test_layers/test_1d/checks_1d/check_layer_1d.py @@ -0,0 +1,552 @@ +import torch +import torch.distributed as dist +from torch.nn import Parameter + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn import ( + Classifier1D, + Embedding1D, + Linear1D_Col, + Linear1D_Row, + VanillaClassifier, + VocabParallelClassifier1D, + VocabParallelCrossEntropyLoss1D, + VocabParallelEmbedding1D, +) +from colossalai.utils import get_current_device, print_rank_0 + +from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal + + +def check_linear_col(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = Linear1D_Col(INPUT_SIZE, OUTPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = A_master.clone() + A.requires_grad = True + + W_shape = (OUTPUT_SIZE, INPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=0)[i] + W = W.clone() + W.requires_grad = True + + B_shape = (OUTPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + dist.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + B = B.clone() + B.requires_grad = True + + layer.weight = Parameter(W) + layer.bias = Parameter(B) + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = torch.chunk(C_master, DEPTH, dim=-1)[i] + + check_equal(out, C) + print_rank_0('linear_col forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + dist.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, layer.bias.grad) + + print_rank_0('linear_col backward: pass') + + +def check_linear_row(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=-1)[i] + A = A.clone() + A.requires_grad = True + + W_shape = (INPUT_SIZE, OUTPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[i] + W = W.clone() + W.requires_grad = True + + B_shape = (INPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + dist.broadcast(B_master, src=0) + B = B_master.clone() + B.requires_grad = True + + layer.weight = Parameter(W) + layer.bias = Parameter(B) + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = C_master.clone() + + check_equal(out, C) + print_rank_0('linear_row forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + dist.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i] + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + check_equal(B_grad, layer.bias.grad) + + print_rank_0('linear_row backward: pass') + + +def check_embed(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = C_master.clone() + check_equal(out, C) + print_rank_0('embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('embed backward: pass') + + +def check_vocab_parallel_embed(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = C_master.clone() + check_equal(out, C) + print_rank_0('vocab parallel embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('vocab parallel embed backward: pass') + + +def check_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + env.parallel_input_1d = False + parallel_input_1d = env.parallel_input_1d + layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, bias=True) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, bias=True) + layer_master = layer_master.to(dtype).to(device) + + W_master = layer_master.weight.data + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[i] + layer.weight.data.copy_(W) + B_master = layer_master.bias.data + dist.broadcast(B_master, src=0) + B = B_master.clone() + layer.bias.data.copy_(B) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + if parallel_input_1d: + A = torch.chunk(A_master, DEPTH, dim=-1)[i] + A = A.clone() + else: + A = A_master.clone() + A.requires_grad = True + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = C_master.clone() + + check_equal(out, C) + print_rank_0('classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + if parallel_input_1d: + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i] + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + check_equal(B_grad, layer.bias.grad) + + print_rank_0('classifier (no given weight) backward: pass') + + +def check_vocab_parallel_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = VocabParallelClassifier1D(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer_master = layer_master.to(dtype).to(device) + + W_master = layer_master.weight.data + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=0)[i] + layer.weight.data.copy_(W) + B_master = layer_master.bias.data + dist.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + layer.bias.data.copy_(B) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = A_master.clone() + A.requires_grad = True + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=-1)[i] + + check_equal(out, C) + print_rank_0('vocab parallel classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, layer.bias.grad) + + print_rank_0('vocab parallel classifier (no given weight) backward: pass') + + +def check_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[i] + embed.weight.data.copy_(weight) + + env.parallel_input_1d = False + layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = C_master.clone() + check_equal(out, C) + print_rank_0('classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, embed.weight.grad) + + print_rank_0('classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + env.parallel_input_1d = False + layer = VocabParallelClassifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=-1)[i] + check_equal(out, C) + print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, embed.weight.grad) + + print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_loss(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + criterion = VocabParallelCrossEntropyLoss1D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, SEQ_LENGTH), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=-1)[i] + out = out.clone() + out.requires_grad = True + + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('vocab parallel loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i] + check_equal(out_grad, out.grad) + print_rank_0('vocab parallel loss backward: pass') + + +@torch.no_grad() +def check_linear_row_stream_inference(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + stream_chunk_num = 4 + assert HIDDEN_SIZE % stream_chunk_num == 0 + layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=stream_chunk_num) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=-1)[i] + A = A.clone() + + W_shape = (INPUT_SIZE, OUTPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[i] + W = W.clone() + + B_shape = (INPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + dist.broadcast(B_master, src=0) + B = B_master.clone() + + layer.weight = Parameter(W) + layer.bias = Parameter(B) + layer.chunk_weight() + layer.eval() + + out = layer(A) + + A_master = A_master.clone() + W_master = W_master.clone() + B_master = B_master.clone() + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = C_master.clone() + + check_equal(out, C) + print_rank_0('linear_row forward: pass') diff --git a/tests/test_layers/test_1d/checks_1d/common.py b/tests/test_layers/test_1d/checks_1d/common.py new file mode 100644 index 0000000000000000000000000000000000000000..76a255bc90954ffe7c09a63c2a1c8578a04b78fc --- /dev/null +++ b/tests/test_layers/test_1d/checks_1d/common.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +DEPTH = 4 +BATCH_SIZE = 8 +SEQ_LENGTH = 8 +IMG_SIZE = 16 +HIDDEN_SIZE = 8 +NUM_CLASSES = 8 +VOCAB_SIZE = 16 + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True diff --git a/tests/test_layers/test_1d/test_1d.py b/tests/test_layers/test_1d/test_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..897590f0d9c8f0179a61b9c311af232eb4076171 --- /dev/null +++ b/tests/test_layers/test_1d/test_1d.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from checks_1d.check_layer_1d import * + +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port + +CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),) + + +def check_layer(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + check_linear_col() + check_linear_row() + check_embed() + check_vocab_parallel_embed() + check_classifier_no_given_weight() + check_vocab_parallel_classifier_no_given_weight() + check_classifier_given_embed_weight() + check_vocab_parallel_classifier_given_embed_weight() + check_vocab_parallel_loss() + + check_linear_row_stream_inference() + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_1d(): + world_size = 4 + run_func = partial(check_layer, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_1d() diff --git a/tests/test_layers/test_2d/checks_2d/__init__.py b/tests/test_layers/test_2d/checks_2d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_layers/test_2d/checks_2d/check_layer_2d.py b/tests/test_layers/test_2d/checks_2d/check_layer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..e030e473a36311250bd0310b94261a46b77d561a --- /dev/null +++ b/tests/test_layers/test_2d/checks_2d/check_layer_2d.py @@ -0,0 +1,741 @@ +import torch +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn import (Classifier2D, CrossEntropyLoss2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, + VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2D, + VocabParallelCrossEntropyLoss2D, VocabParallelEmbedding2D) +from colossalai.utils import get_current_device, print_rank_0 + +from .common import (BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal) + + +def check_linear(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = HIDDEN_SIZE + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + layer = Linear2D(INPUT_SIZE, OUTPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + W_shape = (INPUT_SIZE, OUTPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + torch.distributed.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=0)[i] + W = torch.chunk(W, DEPTH, dim=-1)[j] + W = W.clone() + W.requires_grad = True + + B_shape = (OUTPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=-1)[j] + B = torch.chunk(B, DEPTH, dim=-1)[i] + B = B.clone() + B.requires_grad = True + + layer.weight.data.copy_(W) + layer.bias.data.copy_(B) + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master) + B_master + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + + check_equal(out, C) + print_rank_0('linear forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] + # if i == 0: + check_equal(B_grad, layer.bias.grad) + + print_rank_0('linear backward: pass') + + +def check_layernorm(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + EPS = 1e-12 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + layernorm = LayerNorm2D(INPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + out = layernorm(A) + + A_master = A_master.clone() + A_master.requires_grad = True + E_master = torch.sum(A_master, dim=-1, keepdim=True) + E_master /= INPUT_SIZE + V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True) + V_master /= INPUT_SIZE + V_master = V_master - E_master * E_master + V_master = 1.0 / torch.sqrt(V_master + EPS) + C_master = (A_master - E_master) * V_master + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + + check_equal(out, C) + print_rank_0('layer norm forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + out.backward(grad) + + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + check_equal(A_grad, A.grad) + print_rank_0('layer norm backward: pass') + + +def check_embed(): + device = get_current_device() + dtype = torch.float32 + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + embed = Embedding2D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[j] + weight = torch.chunk(weight, DEPTH, dim=-1)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + check_equal(out, C) + print_rank_0('embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('embed backward: pass') + + +def check_patch_embed(): + device = get_current_device() + dtype = torch.float32 + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + layer = PatchEmbedding2D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) + torch.nn.init.ones_(layer.cls_token) + torch.nn.init.ones_(layer.pos_embed) + layer = layer.to(device) + + layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) + torch.nn.init.ones_(layer_master.cls_token) + torch.nn.init.ones_(layer_master.pos_embed) + layer_master = layer_master.to(device) + + proj_weight_master = layer_master.weight.data + torch.distributed.broadcast(proj_weight_master, src=0) + proj_weight = torch.chunk(proj_weight_master, DEPTH, dim=0)[j] + proj_weight = torch.chunk(proj_weight, DEPTH, dim=0)[i] + layer.weight.data.copy_(proj_weight) + proj_bias_master = layer_master.bias.data + torch.distributed.broadcast(proj_bias_master, src=0) + proj_bias = torch.chunk(proj_bias_master, DEPTH, dim=0)[j] + proj_bias = torch.chunk(proj_bias, DEPTH, dim=0)[i] + layer.bias.data.copy_(proj_bias) + + A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(A) + + A_master = A_master.clone() + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + check_equal(out, C) + print_rank_0('patch embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + cls_grad_master = layer_master.cls_token.grad + cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[j] + cls_grad = torch.chunk(cls_grad, DEPTH, dim=-1)[i] + check_equal(cls_grad, layer.cls_token.grad) + + pos_grad_master = layer_master.pos_embed.grad + pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[j] + pos_grad = torch.chunk(pos_grad, DEPTH, dim=-1)[i] + check_equal(pos_grad, layer.pos_embed.grad) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, layer.weight.grad) + + bias_grad = layer_master.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[j] + bias_grad = torch.chunk(bias_grad, DEPTH)[i] + check_equal(bias_grad, layer.bias.grad) + print_rank_0('patch embed backward: pass') + + +def check_vocab_parallel_embed(): + device = get_current_device() + dtype = torch.float32 + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + embed = VocabParallelEmbedding2D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[j] + weight = torch.chunk(weight, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('vocab parallel embed backward: pass') + + +def check_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = NUM_CLASSES + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + layer = Classifier2D(INPUT_SIZE, OUTPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randint(5, A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + W_shape = (OUTPUT_SIZE, INPUT_SIZE) + W_master = torch.randint(5, W_shape, dtype=dtype, device=device) + torch.distributed.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[j] + W = torch.chunk(W, DEPTH, dim=-1)[i] + W = W.clone() + layer.weight.data.copy_(W) + # W.requires_grad = True + + B_shape = (OUTPUT_SIZE, ) + B_master = torch.randint(5, B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + # B = torch.chunk(B_master, DEPTH, dim=0)[j] + B = B_master.clone() + layer.bias.data.copy_(B) + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = torch.chunk(C_master, DEPTH, dim=0)[i] + # C = torch.chunk(C, DEPTH, dim=-1)[j] + + check_equal(out, C) + print_rank_0('classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + # grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + # B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] + # if i == 0: + check_equal(B_grad, layer.bias.grad) + + print_rank_0('classifier (no given weight) backward: pass') + + +def check_vocab_parallel_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + layer = VocabParallelClassifier2D(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer = layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer_master = layer_master.to(dtype).to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[i] + weight = torch.chunk(weight, DEPTH, dim=-1)[j] + layer.weight.data.copy_(weight) + bias_master = layer_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, DEPTH)[j] + bias = torch.chunk(bias, DEPTH)[i] + layer.bias.data.copy_(bias) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + B_grad = torch.chunk(B_grad, DEPTH)[j] + B_grad = torch.chunk(B_grad, DEPTH)[i] + check_equal(B_grad, layer.bias.grad) + print_rank_0('vocab parallel classifier (no given weight) backward: pass') + + +def check_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + embed = Embedding2D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[j] + weight = torch.chunk(weight, DEPTH, dim=-1)[i] + embed.weight.data.copy_(weight) + + layer = Classifier2D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + check_equal(out, C) + print_rank_0('classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, embed.weight.grad) + print_rank_0('classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + embed = VocabParallelEmbedding2D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[j] + weight = torch.chunk(weight, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + layer = VocabParallelClassifier2D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, embed.weight.grad) + print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + + +def check_loss(): + device = get_current_device() + dtype = torch.float32 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + criterion = CrossEntropyLoss2D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=0)[i] + out = out.clone() + out.requires_grad = True + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('cross entropy loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] + check_equal(out_grad, out.grad) + print_rank_0('cross entropy loss backward: pass') + + +def check_vocab_parallel_loss(): + device = get_current_device() + dtype = torch.float32 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + criterion = VocabParallelCrossEntropyLoss2D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=0)[i] + out = torch.chunk(out, DEPTH, dim=-1)[j] + out = out.clone() + out.requires_grad = True + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('vocab parallel cross entropy loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] + out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[j] + check_equal(out_grad, out.grad) + print_rank_0('vocab parallel cross entropy loss backward: pass') + + +# def check_attention(): +# device = get_current_device() +# dtype = torch.float32 +# INPUT_SIZE = HIDDEN_SIZE +# NUM_ATTENTION_HEADS = 2 + +# j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) +# i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + +# layer = TransformerSelfAttention2D( +# HIDDEN_SIZE, +# NUM_ATTENTION_HEADS, +# attention_dropout_prob=0.5, +# hidden_dropout_prob=0.5, +# ) + +# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) +# A_master = torch.randn(A_shape, dtype=dtype, device=device) +# torch.distributed.broadcast(A_master, src=0) +# A = torch.chunk(A_master, DEPTH, dim=0)[i] +# A = torch.chunk(A, DEPTH, dim=-1)[j] +# A = A.clone() +# A.requires_grad = True + +# mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH) +# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device) + +# out = layer(A, attention_mask) +# assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH) +# print_rank_0('self attention forward: pass') + +# grad_shape = out.shape +# grad = torch.randn(grad_shape, dtype=dtype, device=device) + +# out.backward(grad) +# assert A.grad.shape == A.shape +# print_rank_0('self attention backward: pass') + +# def check_mlp(): +# device = get_current_device() +# dtype = torch.float32 +# INPUT_SIZE = HIDDEN_SIZE + +# j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) +# i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + +# layer = TransformerMLP2D( +# HIDDEN_SIZE, +# dropout_prob=0.5, +# act_func='gelu', +# ) + +# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) +# A_master = torch.randn(A_shape, dtype=dtype, device=device) +# torch.distributed.broadcast(A_master, src=0) +# A = torch.chunk(A_master, DEPTH, dim=0)[i] +# A = torch.chunk(A, DEPTH, dim=-1)[j] +# A = A.clone() +# A.requires_grad = True + +# out = layer(A) +# assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH) +# print_rank_0('mlp forward: pass') + +# grad_shape = out.shape +# grad = torch.randn(grad_shape, dtype=dtype, device=device) + +# out.backward(grad) +# assert A.grad.shape == A.shape +# print_rank_0('mlp backward: pass') + +# def check_transformerlayer(): +# device = get_current_device() +# dtype = torch.float32 +# INPUT_SIZE = HIDDEN_SIZE +# NUM_ATTENTION_HEADS = 2 + +# j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) +# i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + +# layer = TransformerLayer2D(HIDDEN_SIZE, +# NUM_ATTENTION_HEADS, +# act_func='gelu', +# attention_dropout_prob=0.5, +# hidden_dropout_prob=0.5) + +# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) +# A_master = torch.randn(A_shape, dtype=dtype, device=device) +# torch.distributed.broadcast(A_master, src=0) +# A = torch.chunk(A_master, DEPTH, dim=0)[i] +# A = torch.chunk(A, DEPTH, dim=-1)[j] +# A = A.clone() +# A.requires_grad = True + +# mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH) +# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device) + +# out = layer(A, attention_mask) +# assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH) +# print_rank_0('transformerlayer forward: pass') + +# grad_shape = out.shape +# grad = torch.randn(grad_shape, dtype=dtype, device=device) + +# out.backward(grad) +# assert A.grad.shape == A.shape +# print_rank_0('transformerlayer backward: pass') diff --git a/tests/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_layers/test_2d/checks_2d/check_operation_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e37b1ec3097b29b55a4744111131ef3bfdea44 --- /dev/null +++ b/tests/test_layers/test_2d/checks_2d/check_operation_2d.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D +from colossalai.utils import get_current_device +from colossalai.utils import print_rank_0 +from .common import check_equal, BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE, DEPTH + + +def check_AB(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + dtype = torch.float + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + B = torch.chunk(B, DEPTH, dim=-1)[j] + B = B.clone() + B.requires_grad = True + + out_shape = (BATCH_SIZE // DEPTH, SEQ_LENGTH, 4 * HIDDEN_SIZE // DEPTH) + + out = Matmul_AB_2D.apply(A, B, DEPTH, out_shape, i, j, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, + data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + A_master = A_master.clone() + A_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, B_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + # check forward correctness + check_equal(out, C) + print_rank_0('AB forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + + out.backward(grad) + + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + # check backward correctness + check_equal(A_grad, A.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + # check backward correctness + check_equal(B_grad, B.grad) + print_rank_0('AB backward: pass') + + +def check_ABT(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + dtype = torch.float + device = get_current_device() + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + C_master = torch.randn(C_shape, dtype=dtype, device=device) + torch.distributed.broadcast(C_master, src=0) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + C = C.clone() + C.requires_grad = True + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + B = torch.chunk(B, DEPTH, dim=-1)[j] + B = B.clone() + B.requires_grad = True + + out = Matmul_ABT_2D.apply(C, B, DEPTH, (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH), i, j, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank, + pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + C_master = C_master.clone() + C_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + A_master = torch.matmul(C_master, B_master.transpose(0, 1)) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + check_equal(out, A) + print_rank_0('ABT forward: pass') + + grad_shape = A_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + + # backward + out.backward(grad) + + A_master.backward(grad_master) + C_grad = C_master.grad + C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i] + C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j] + check_equal(C_grad, C.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + check_equal(B_grad, B.grad) + print_rank_0('ABT backward: pass') + + +def check_ATB(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + device = get_current_device() + dtype = torch.float + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + C_master = torch.randn(C_shape, dtype=dtype, device=device) + torch.distributed.broadcast(C_master, src=0) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + C = C.clone() + C.requires_grad = True + + out = Matmul_ATB_2D.apply(A, C, DEPTH, (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH), i, j, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank, + pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + A_master = A_master.clone() + A_master.requires_grad = True + C_master = C_master.clone() + C_master.requires_grad = True + B_master = torch.matmul( + A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1])) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + B = torch.chunk(B, DEPTH, dim=-1)[j] + check_equal(out, B) + print_rank_0('ATB forward: pass') + + grad_shape = B_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + + out.backward(grad) + + B_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + check_equal(A_grad, A.grad) + + C_grad = C_master.grad + C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i] + C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j] + check_equal(C_grad, C.grad) + print_rank_0('ATB backward: pass') diff --git a/tests/test_layers/test_2d/checks_2d/common.py b/tests/test_layers/test_2d/checks_2d/common.py new file mode 100644 index 0000000000000000000000000000000000000000..8c855c18bc26c7e06507f2b180d87cb0cae8b67f --- /dev/null +++ b/tests/test_layers/test_2d/checks_2d/common.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +DEPTH = 2 +BATCH_SIZE = 8 +SEQ_LENGTH = 8 +HIDDEN_SIZE = 8 +NUM_CLASSES = 8 +VOCAB_SIZE = 16 +IMG_SIZE = 16 + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) diff --git a/tests/test_layers/test_2d/test_2d.py b/tests/test_layers/test_2d/test_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..da235d0cf168cf743257d433e43fd91d69699e6a --- /dev/null +++ b/tests/test_layers/test_2d/test_2d.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use +from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight, + check_embed, check_layernorm, check_linear, check_loss, check_patch_embed, + check_vocab_parallel_classifier_given_embed_weight, + check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed, + check_vocab_parallel_loss) +from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB + +CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')),) + + +def check_operations(): + check_AB() + check_ABT() + check_ATB() + + +def check_layer(): + check_linear() + check_layernorm() + check_embed() + check_patch_embed() + check_vocab_parallel_embed() + check_classifier_no_given_weight() + check_vocab_parallel_classifier_no_given_weight() + check_classifier_given_embed_weight() + check_vocab_parallel_classifier_given_embed_weight() + check_loss() + check_vocab_parallel_loss() + + +def check_layer_and_operation(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + # check_operations() + check_layer() + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_2d(): + world_size = 4 + run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_2d() diff --git a/tests/test_layers/test_2p5d/checks_2p5d/__init__.py b/tests/test_layers/test_2p5d/checks_2p5d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py b/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py new file mode 100644 index 0000000000000000000000000000000000000000..a8f551093b1ef782f8bef64d4241c1400aa6bdde --- /dev/null +++ b/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py @@ -0,0 +1,754 @@ +import torch +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn import (Classifier2p5D, CrossEntropyLoss2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, + PatchEmbedding2p5D, VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2p5D, + VocabParallelCrossEntropyLoss2p5D, VocabParallelEmbedding2p5D) +from colossalai.utils import get_current_device, print_rank_0 +from torch.nn import Parameter + +from .common import * + + +def check_linear(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + layer = Linear2p5D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, skip_bias_add=False) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + W_shape = (INPUT_SIZE, OUTPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + torch.distributed.broadcast(W_master, src=0) + W = torch.chunk(W_master, TESSERACT_DIM, dim=0)[i] + W = torch.chunk(W, TESSERACT_DIM, dim=-1)[j] + W = W.clone() + W.requires_grad = True + + B_shape = (OUTPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j] + B = B.clone() + B.requires_grad = True + + layer.weight = Parameter(W) + layer.bias = Parameter(B) + out = layer(A) + bias = layer.bias + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master) + B_master + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + + check_equal(out, C) + print_rank_0('linear forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i] + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j] + if i == 0: + check_equal(B_grad, layer.bias.grad) + + print_rank_0('linear backward: pass') + + +def check_layernorm(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + EPS = 1e-12 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + layernorm = LayerNorm2p5D(INPUT_SIZE, dtype=dtype) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + out = layernorm(A) + + A_master = A_master.clone() + A_master.requires_grad = True + E_master = torch.sum(A_master, dim=-1, keepdim=True) + E_master /= INPUT_SIZE + V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True) + V_master /= INPUT_SIZE + V_master = V_master - E_master * E_master + V_master = 1.0 / torch.sqrt(V_master + EPS) + C_master = (A_master - E_master) * V_master + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + + check_equal(out, C) + print_rank_0('layer norm forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + out.backward(grad) + + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(A_grad, A.grad) + print_rank_0('layer norm backward: pass') + + +def check_embed(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j] + weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + check_equal(out, C) + print_rank_0('embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('embed backward: pass') + + +def check_patch_embed(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + layer = PatchEmbedding2p5D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) + torch.nn.init.ones_(layer.cls_token) + torch.nn.init.ones_(layer.pos_embed) + layer = layer.to(device) + + layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) + torch.nn.init.ones_(layer_master.cls_token) + torch.nn.init.ones_(layer_master.pos_embed) + layer_master = layer_master.to(device) + + proj_weight_master = layer_master.weight.data + torch.distributed.broadcast(proj_weight_master, src=0) + proj_weight = torch.chunk(proj_weight_master, TESSERACT_DIM, dim=0)[j] + proj_weight = torch.chunk(proj_weight, TESSERACT_DIM, dim=0)[i] + layer.weight.data.copy_(proj_weight) + proj_bias_master = layer_master.bias.data + torch.distributed.broadcast(proj_bias_master, src=0) + proj_bias = torch.chunk(proj_bias_master, TESSERACT_DIM, dim=0)[j] + proj_bias = torch.chunk(proj_bias, TESSERACT_DIM, dim=0)[i] + layer.bias.data.copy_(proj_bias) + + A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(A) + + A_master = A_master.clone() + C_master = layer_master(A_master) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + check_equal(out, C) + print_rank_0('patch embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + cls_grad_master = layer_master.cls_token.grad + cls_grad = torch.chunk(cls_grad_master, TESSERACT_DIM, dim=-1)[j] + cls_grad = torch.chunk(cls_grad, TESSERACT_DIM, dim=-1)[i] + check_equal(cls_grad, layer.cls_token.grad) + + pos_grad_master = layer_master.pos_embed.grad + pos_grad = torch.chunk(pos_grad_master, TESSERACT_DIM, dim=-1)[j] + pos_grad = torch.chunk(pos_grad, TESSERACT_DIM, dim=-1)[i] + check_equal(pos_grad, layer.pos_embed.grad) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j] + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] + check_equal(B_grad, layer.weight.grad) + + bias_grad = layer_master.bias.grad + bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[j] + bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[i] + check_equal(bias_grad, layer.bias.grad) + print_rank_0('patch embed backward: pass') + + +def check_vocab_parallel_embed(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j] + weight = torch.chunk(weight, TESSERACT_DIM, dim=0)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('vocab parallel embed backward: pass') + + +def check_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = NUM_CLASSES + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + + layer = Classifier2p5D(INPUT_SIZE, OUTPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randint(5, A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + W_shape = (OUTPUT_SIZE, INPUT_SIZE) + W_master = torch.randint(5, W_shape, dtype=dtype, device=device) + torch.distributed.broadcast(W_master, src=0) + # W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j] + W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j] + W = torch.chunk(W, TESSERACT_DIM, dim=-1)[i] + W = W.clone() + layer.weight.data.copy_(W) + # W.requires_grad = True + + B_shape = (OUTPUT_SIZE, ) + B_master = torch.randint(5, B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + # B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j] + B = B_master.clone() + layer.bias.data.copy_(B) + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + # C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + + check_equal(out, C) + print_rank_0('classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + # grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + # B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j] + # if i == 0: + check_equal(B_grad, layer.bias.grad) + + print_rank_0('classifier (no given weight) backward: pass') + + +def check_vocab_parallel_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer = layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer_master = layer_master.to(dtype).to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, TESSERACT_DIM, dim=0)[i] + weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[j] + layer.weight.data.copy_(weight) + bias_master = layer_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, TESSERACT_DIM)[j] + layer.bias.data.copy_(bias) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i] + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM)[j] + if i == 0: + check_equal(B_grad, layer.bias.grad) + print_rank_0('vocab parallel classifier (no given weight) backward: pass') + + +def check_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j] + weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[i] + embed.weight.data.copy_(weight) + + layer = Classifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + check_equal(out, C) + print_rank_0('classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i] + check_equal(W_grad, embed.weight.grad) + print_rank_0('classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j] + weight = torch.chunk(weight, TESSERACT_DIM, dim=0)[i] + embed.weight.data.copy_(weight) + + layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i] + check_equal(W_grad, embed.weight.grad) + print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + + +def check_loss(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + criterion = CrossEntropyLoss2p5D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i] + out = out.clone() + out.requires_grad = True + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('cross entropy loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i] + check_equal(out_grad, out.grad) + print_rank_0('cross entropy loss backward: pass') + + +def check_vocab_parallel_loss(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + criterion = VocabParallelCrossEntropyLoss2p5D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i] + out = torch.chunk(out, TESSERACT_DIM, dim=-1)[j] + out = out.clone() + out.requires_grad = True + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('vocab parallel cross entropy loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i] + out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(out_grad, out.grad) + print_rank_0('vocab parallel cross entropy loss backward: pass') + + +# def check_attention(): +# device = get_current_device() +# dtype = torch.float32 +# INPUT_SIZE = HIDDEN_SIZE +# NUM_ATTENTION_HEADS = 2 + +# i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) +# j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) +# k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + +# layer = TransformerSelfAttention2p5D( +# HIDDEN_SIZE, NUM_ATTENTION_HEADS, +# attention_dropout_prob=0.5, +# hidden_dropout_prob=0.5, +# dtype=dtype, +# ) + +# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) +# A_master = torch.randn(A_shape, dtype=dtype, device=device) +# torch.distributed.broadcast(A_master, src=0) +# A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] +# A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] +# A = A.clone() +# A.requires_grad = True + +# mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH) +# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device) + +# out = layer(A, attention_mask) +# assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM) +# print_rank_0('self attention forward: pass') + +# grad_shape = out.shape +# grad = torch.randn(grad_shape, dtype=dtype, device=device) + +# out.backward(grad) +# assert A.grad.shape == A.shape +# print_rank_0('self attention backward: pass') + +# def check_mlp(): +# device = get_current_device() +# dtype = torch.float32 +# INPUT_SIZE = HIDDEN_SIZE + +# i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) +# j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) +# k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + +# layer = TransformerMLP2p5D( +# HIDDEN_SIZE, +# mlp_ratio=1, +# dropout_prob=0.5, +# act_func='gelu', +# dtype=dtype, +# ) + +# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) +# A_master = torch.randn(A_shape, dtype=dtype, device=device) +# torch.distributed.broadcast(A_master, src=0) +# A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] +# A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] +# A = A.clone() +# A.requires_grad = True + +# out = layer(A) +# assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM) +# print_rank_0('mlp forward: pass') + +# grad_shape = out.shape +# grad = torch.randn(grad_shape, dtype=dtype, device=device) + +# out.backward(grad) +# assert A.grad.shape == A.shape +# print_rank_0('mlp backward: pass') + +# def check_transformerlayer(): +# device = get_current_device() +# dtype = torch.float32 +# INPUT_SIZE = HIDDEN_SIZE +# NUM_ATTENTION_HEADS = 2 + +# i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) +# j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) +# k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + +# layer = TransformerLayer2p5D( +# HIDDEN_SIZE, +# NUM_ATTENTION_HEADS, +# act_func='gelu', +# attention_dropout_prob=0.5, +# hidden_dropout_prob=0.5, +# dtype=dtype, +# ) + +# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) +# A_master = torch.randn(A_shape, dtype=dtype, device=device) +# torch.distributed.broadcast(A_master, src=0) +# A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] +# A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] +# A = A.clone() +# A.requires_grad = True + +# mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH) +# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device) + +# out = layer(A, attention_mask) +# assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM) +# print_rank_0('transformerlayer forward: pass') + +# grad_shape = out.shape +# grad = torch.randn(grad_shape, dtype=dtype, device=device) + +# out.backward(grad) +# assert A.grad.shape == A.shape +# print_rank_0('transformerlayer backward: pass') diff --git a/tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py b/tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py new file mode 100644 index 0000000000000000000000000000000000000000..d0c3b02fccba589507bf8e2af25846767636c734 --- /dev/null +++ b/tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py @@ -0,0 +1,216 @@ +import torch + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, \ + Matmul_ATB_2p5D +from colossalai.utils import get_current_device +from colossalai.utils import print_rank_0 +from .common import * + + +def check_AB(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + dtype = torch.float + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] + B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] + B = B.clone() + B.requires_grad = True + + out_shape = (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 4 * HIDDEN_SIZE // TESSERACT_DIM) + out = Matmul_AB_2p5D.apply(A, B, TESSERACT_DIM, out_shape, i, j, k, ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank, pipeline_parallel_rank, + pipeline_parallel_size, tensor_parallel_size) + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + A_master = A_master.clone() + A_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, B_master) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + # check forward correctness + check_equal(out, C) + print_rank_0('AB forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + + out.backward(grad) + + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + # check backward correctness + check_equal(A_grad, A.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] + # check backward correctness + check_equal(B_grad, B.grad) + print_rank_0('AB backward: pass') + + +def check_ABT(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + dtype = torch.float + device = get_current_device() + + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + C_master = torch.randn(C_shape, dtype=dtype, device=device) + torch.distributed.broadcast(C_master, src=0) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + C = C.clone() + C.requires_grad = True + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] + B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] + B = B.clone() + B.requires_grad = True + + out = Matmul_ABT_2p5D.apply(C, B, TESSERACT_DIM, + (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM), i, j, k, + ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank, + pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + C_master = C_master.clone() + C_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + A_master = torch.matmul(C_master, B_master.transpose(0, 1)) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + check_equal(out, A) + print_rank_0('ABT forward: pass') + + grad_shape = A_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + + # backward + out.backward(grad) + + A_master.backward(grad_master) + C_grad = C_master.grad + C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i] + C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(C_grad, C.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(B_grad, B.grad) + print_rank_0('ABT backward: pass') + + +def check_ATB(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + device = get_current_device() + dtype = torch.float + + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + C_master = torch.randn(C_shape, dtype=dtype, device=device) + torch.distributed.broadcast(C_master, src=0) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + C = C.clone() + C.requires_grad = True + + out = Matmul_ATB_2p5D.apply(A, C, TESSERACT_DIM, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM), + i, j, k, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, + data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, + tensor_parallel_size) + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + A_master = A_master.clone() + A_master.requires_grad = True + C_master = C_master.clone() + C_master.requires_grad = True + B_master = torch.matmul( + A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1])) + B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] + B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] + check_equal(out, B) + print_rank_0('ATB forward: pass') + + grad_shape = B_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + + out.backward(grad) + + B_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(A_grad, A.grad) + + C_grad = C_master.grad + C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i] + C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(C_grad, C.grad) + print_rank_0('ATB backward: pass') diff --git a/tests/test_layers/test_2p5d/checks_2p5d/common.py b/tests/test_layers/test_2p5d/checks_2p5d/common.py new file mode 100644 index 0000000000000000000000000000000000000000..aff85f109666d7cdf9e65173eda851368c39694c --- /dev/null +++ b/tests/test_layers/test_2p5d/checks_2p5d/common.py @@ -0,0 +1,14 @@ +import torch + +TESSERACT_DIM = 2 +TESSERACT_DEP = 2 +BATCH_SIZE = 8 +SEQ_LENGTH = 8 +HIDDEN_SIZE = 8 +NUM_CLASSES = 8 +VOCAB_SIZE = 16 +IMG_SIZE = 16 + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) \ No newline at end of file diff --git a/tests/test_layers/test_2p5d/test_2p5d.py b/tests/test_layers/test_2p5d/test_2p5d.py new file mode 100644 index 0000000000000000000000000000000000000000..365e2d934df8d85832791d0fc5a91113c3c773f4 --- /dev/null +++ b/tests/test_layers/test_2p5d/test_2p5d.py @@ -0,0 +1,62 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use +from checks_2p5d.check_layer_2p5d import * +from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB + +CONFIG = dict(parallel=dict( + pipeline=dict(size=1), + tensor=dict(size=4, mode='2.5d', depth=1), +),) + + +def check_operations(): + check_AB() + check_ABT() + check_ATB() + + +def check_layer(): + check_linear() + check_layernorm() + check_embed() + check_patch_embed() + check_vocab_parallel_embed() + check_classifier_no_given_weight() + check_vocab_parallel_classifier_no_given_weight() + check_classifier_given_embed_weight() + check_vocab_parallel_classifier_given_embed_weight() + check_loss() + check_vocab_parallel_loss() + + +def check_layer_and_operation(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + check_operations() + check_layer() + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_2p5d(): + world_size = 4 + run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_2p5d() diff --git a/tests/test_layers/test_3d/checks_3d/__init__.py b/tests/test_layers/test_3d/checks_3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..9e199e22e8aca3d45d27ad12b34e6337c67af857 --- /dev/null +++ b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py @@ -0,0 +1,864 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import time + +import torch +from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.core import global_context +from colossalai.logging import get_dist_logger +from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D, + VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier3D, + VocabParallelCrossEntropyLoss3D, VocabParallelEmbedding3D) +from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env +from colossalai.utils import get_current_device, print_rank_0 + +from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal + + +def check_linear(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, bias=True) + layer = layer.to(device) + layer_master = torch.nn.Linear(INPUT_SIZE, OUTPUT_SIZE) + layer_master = layer_master.to(device) + + weight_master = layer_master.weight.data.transpose(0, 1).contiguous() + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[k] + weight = torch.chunk(weight, DEPTH, dim=-1)[j] + weight = torch.chunk(weight, DEPTH, dim=-1)[i] + layer.weight.data.copy_(weight) + bias_master = layer_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, DEPTH)[j] + layer.bias.data.copy_(bias) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[k] + A = torch.chunk(A, DEPTH, dim=0)[j] + A = A.clone() + A.requires_grad = True + + fwd_start = time.time() + out = layer(A) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'linear forward: {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + C = torch.chunk(C, DEPTH, dim=0)[k] + logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = torch.chunk(grad, DEPTH, dim=0)[k] + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('linear backward: {:.3f} s'.format(bwd_end - bwd_start), logger) + + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] + logger.info('Rank {} linear backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad))) + + B_grad = layer_master.weight.grad.transpose(0, 1) + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] + logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad))) + + bias_grad = layer_master.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[j] + logger.info('Rank {} linear backward (bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_layernorm(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + INPUT_SIZE = HIDDEN_SIZE + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + norm = LayerNorm3D(INPUT_SIZE, eps=1e-6) + norm = norm.to(device) + norm_master = torch.nn.LayerNorm(INPUT_SIZE, eps=1e-6) + norm_master = norm_master.to(device) + + weight_master = norm_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH)[k] + norm.weight.data.copy_(weight) + bias_master = norm_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, DEPTH)[k] + norm.bias.data.copy_(bias) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[k] + A = torch.chunk(A, DEPTH, dim=0)[j] + A = A.clone() + A.requires_grad = True + + fwd_start = time.time() + out = norm(A) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'layer norm forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), + fwd_end - fwd_start), logger) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = norm_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[k] + C = torch.chunk(C, DEPTH, dim=0)[j] + logger.info('Rank {} layernorm forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[k] + grad = torch.chunk(grad, DEPTH, dim=0)[j] + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('layer norm backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] + logger.info('Rank {} layernorm backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad))) + + bias_grad = norm_master.weight.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[k] + logger.info('Rank {} layernorm backward (weight_grad): {}'.format(rank, check_equal(bias_grad, norm.weight.grad))) + + bias_grad = norm_master.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[k] + logger.info('Rank {} layernorm backward (bias_grad): {}'.format(rank, check_equal(bias_grad, norm.bias.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_classifier_no_given_weight(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + INPUT_SIZE = HIDDEN_SIZE + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, bias=True) + layer = layer.to(device) + + layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True) + layer_master = layer_master.to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[k] + layer.weight.data.copy_(weight) + bias_master = layer_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + layer.bias.data.copy_(bias_master) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[k] + A = torch.chunk(A, DEPTH, dim=0)[j] + A = A.clone() + A.requires_grad = True + + fwd_start = time.time() + out = layer(A) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=0)[j] + logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=0)[j] + grad = grad.clone() + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] + logger.info('Rank {} classifier (no given weight) backward (input_grad): {}'.format( + rank, check_equal(A_grad, A.grad))) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + if j == k: + logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format( + rank, check_equal(B_grad, layer.weight.grad))) + else: + logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format( + rank, layer.weight.grad is None)) + + bias_grad = layer_master.bias.grad + logger.info('Rank {} classifier (no given weight) backward (bias_grad): {}'.format( + rank, check_equal(bias_grad, layer.bias.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_vocab_parallel_classifier_no_given_weight(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + INPUT_SIZE = HIDDEN_SIZE + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + layer = VocabParallelClassifier3D(INPUT_SIZE, VOCAB_SIZE, bias=True) + layer = layer.to(device) + + layer_master = VanillaClassifier(INPUT_SIZE, VOCAB_SIZE, bias=True) + layer_master = layer_master.to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[j] + weight = torch.chunk(weight, DEPTH, dim=0)[i] + weight = torch.chunk(weight, DEPTH, dim=-1)[k] + layer.weight.data.copy_(weight) + bias_master = layer_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, DEPTH)[j] + layer.bias.data.copy_(bias) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[k] + A = torch.chunk(A, DEPTH, dim=0)[j] + A = A.clone() + A.requires_grad = True + + fwd_start = time.time() + out = layer(A) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + C = torch.chunk(C, DEPTH, dim=0)[k] + logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = torch.chunk(grad, DEPTH, dim=0)[k] + grad = grad.clone() + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('vocab parallel classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), + logger) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] + logger.info('Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}'.format( + rank, check_equal(A_grad, A.grad))) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format( + rank, check_equal(B_grad, layer.weight.grad))) + + bias_grad = layer_master.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[j] + logger.info('Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}'.format( + rank, check_equal(bias_grad, layer.bias.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_classifier_given_embed_weight(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + dtype = torch.float32 + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + embed = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[k] + embed.weight.data.copy_(weight) + + layer = Classifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + + fwd_start = time.time() + out = layer(embed(A)) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=0)[j] + logger.info('Rank {} classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=0)[j] + grad = grad.clone() + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + if j == k: + logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format( + rank, check_equal(B_grad, embed.weight.grad))) + else: + logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format( + rank, embed.weight.grad is None)) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_vocab_parallel_classifier_given_embed_weight(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + embed = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(device) + + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[j] + weight = torch.chunk(weight, DEPTH, dim=0)[i] + weight = torch.chunk(weight, DEPTH, dim=-1)[k] + embed.weight.data.copy_(weight) + + layer = VocabParallelClassifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + + fwd_start = time.time() + out = layer(embed(A)) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + C = torch.chunk(C, DEPTH, dim=0)[k] + logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = torch.chunk(grad, DEPTH, dim=0)[k] + grad = grad.clone() + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('vocab parallel classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), + logger) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank, + check_equal(B_grad, + embed.weight.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_patch_embed(): + rank = torch.distributed.get_rank() + device = get_current_device() + logger = get_dist_logger() + dtype = torch.float32 + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE) + torch.nn.init.ones_(layer.cls_token) + torch.nn.init.ones_(layer.pos_embed) + layer = layer.to(device) + + layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE) + torch.nn.init.ones_(layer_master.cls_token) + torch.nn.init.ones_(layer_master.pos_embed) + layer_master = layer_master.to(device) + + proj_weight_master = layer_master.weight.data + torch.distributed.broadcast(proj_weight_master, src=0) + proj_weight = torch.chunk(proj_weight_master, DEPTH, dim=0)[k] + layer.weight.data.copy_(proj_weight) + proj_bias_master = layer_master.bias.data + torch.distributed.broadcast(proj_bias_master, src=0) + proj_bias = torch.chunk(proj_bias_master, DEPTH)[k] + layer.bias.data.copy_(proj_bias) + + A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE) + A_master = torch.randn(A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + + fwd_start = time.time() + out = layer(A) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'patch embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), + fwd_end - fwd_start), logger) + + A_master = A_master.clone() + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[k] + C = torch.chunk(C, DEPTH, dim=0)[j] + logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[k] + grad = torch.chunk(grad, DEPTH, dim=0)[j] + grad = grad.clone() + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('patch embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + cls_grad_master = layer_master.cls_token.grad + cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k] + logger.info('Rank {} patch embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad))) + + pos_grad_master = layer_master.pos_embed.grad + pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k] + logger.info('Rank {} patch embed backward (pos_embed_grad): {}'.format(rank, + check_equal(pos_grad, layer.pos_embed.grad))) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] + logger.info('Rank {} patch embed backward (proj_weight_grad): {}'.format(rank, + check_equal(B_grad, layer.weight.grad))) + + bias_grad = layer_master.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[k] + logger.info('Rank {} patch embed backward (proj_bias_grad): {}'.format(rank, + check_equal(bias_grad, layer.bias.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_embed(): + rank = torch.distributed.get_rank() + device = get_current_device() + logger = get_dist_logger() + dtype = torch.float32 + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + layer = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE) + layer = layer.to(device) + layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + layer_master = layer_master.to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[k] + layer.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + + fwd_start = time.time() + out = layer(A) + torch.cuda.synchronize() + fwd_end = time.time() + logger.info('embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), + fwd_end - fwd_start), + ranks=[0]) + + A_master = A_master.clone() + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[k] + C = torch.chunk(C, DEPTH, dim=0)[j] + logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[k] + grad = torch.chunk(grad, DEPTH, dim=0)[j] + grad = grad.clone() + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + logger.info('embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_vocab_parallel_embed(): + rank = torch.distributed.get_rank() + device = get_current_device() + logger = get_dist_logger() + dtype = torch.float32 + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + layer = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE) + layer = layer.to(device) + layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + layer_master = layer_master.to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[j] + weight = torch.chunk(weight, DEPTH, dim=0)[i] + weight = torch.chunk(weight, DEPTH, dim=-1)[k] + layer.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + + fwd_start = time.time() + out = layer(A) + torch.cuda.synchronize() + fwd_end = time.time() + logger.info('vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), + ranks=[0]) + + A_master = A_master.clone() + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[k] + C = torch.chunk(C, DEPTH, dim=0)[j] + logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[k] + grad = torch.chunk(grad, DEPTH, dim=0)[j] + grad = grad.clone() + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + logger.info('vocab parallel embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank, + check_equal(B_grad, + layer.weight.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_loss(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + + criterion = CrossEntropyLoss3D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=0)[i] + out = torch.chunk(out, DEPTH, dim=0)[j] + out = out.clone() + out.requires_grad = True + + fwd_start = time.time() + loss = criterion(out, target_master) + fwd_end = time.time() + logger.info('cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape), + fwd_end - fwd_start), + ranks=[0]) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + logger.info('Rank {} cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master))) + + bwd_start = time.time() + loss.backward() + bwd_end = time.time() + logger.info('cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + + loss_master.backward() + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] + out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j] + logger.info('Rank {} cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_vocab_parallel_loss(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + dtype = torch.float32 + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + criterion = VocabParallelCrossEntropyLoss3D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=0)[i] + out = torch.chunk(out, DEPTH, dim=-1)[k] + out = torch.chunk(out, DEPTH, dim=0)[j] + out = out.clone() + out.requires_grad = True + + fwd_start = time.time() + loss = criterion(out, target_master) + fwd_end = time.time() + logger.info('vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start), + ranks=[0]) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + logger.info('Rank {} vocab parallel cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master))) + + bwd_start = time.time() + loss.backward() + bwd_end = time.time() + logger.info('vocab parallel cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + + loss_master.backward() + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] + out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[k] + out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j] + logger.info('Rank {} vocab parallel cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start diff --git a/tests/test_layers/test_3d/checks_3d/common.py b/tests/test_layers/test_3d/checks_3d/common.py new file mode 100644 index 0000000000000000000000000000000000000000..afb19c4745cc72cd66a4d5a11239ed6f10f68d14 --- /dev/null +++ b/tests/test_layers/test_3d/checks_3d/common.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +DEPTH = 2 +BATCH_SIZE = 8 +SEQ_LENGTH = 8 +HIDDEN_SIZE = 8 +NUM_CLASSES = 8 +NUM_BLOCKS = 2 +IMG_SIZE = 16 +VOCAB_SIZE = 16 + + +def check_equal(A, B): + eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2) + assert eq, f"\nA = {A}\nB = {B}" + return eq \ No newline at end of file diff --git a/tests/test_layers/test_3d/test_3d.py b/tests/test_layers/test_3d/test_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..29a8b3aea239f47c87c61899981c3ea762d7e91d --- /dev/null +++ b/tests/test_layers/test_3d/test_3d.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus +from checks_3d.check_layer_3d import (check_classifier_no_given_weight, check_embed, check_layernorm, check_linear, + check_loss, check_patch_embed, check_vocab_parallel_classifier_given_embed_weight, + check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed, + check_vocab_parallel_loss) + +CONFIG = dict( + parallel=dict( + pipeline=1, + tensor=dict(mode='3d', size=8), + ), + seed=42, +) + + +def check_layer(): + check_linear() + check_layernorm() + check_classifier_no_given_weight() + check_vocab_parallel_classifier_no_given_weight() + check_vocab_parallel_classifier_given_embed_weight() + check_embed() + check_patch_embed() + check_vocab_parallel_embed() + check_loss() + check_vocab_parallel_loss() + + +def check_layer_and_operation(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + check_layer() + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_3d(): + world_size = 8 + run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_3d() diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..cff9072c7a06ea763096870bddf4b179cd254f90 --- /dev/null +++ b/tests/test_layers/test_cache_embedding.py @@ -0,0 +1,373 @@ +import pytest +from functools import partial + +import numpy as np +import random + +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \ + ColoTensor, ColoTensorSpec +from colossalai.nn.parallel.layers import CachedParamMgr, CachedEmbeddingBag, ParallelCachedEmbeddingBag, EvictionStrategy, \ + ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig +from typing import List + +NUM_EMBED, EMBED_DIM = 10, 8 +BATCH_SIZE = 8 + + +def set_seed(seed): + """ + To achieve reproducible results, it's necessary to fix random seeds + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def synthesize_1d_sparse_feature( + batch_size, + num_embed, + device, +): + indices_in_batch = batch_size * 2 + indices = torch.randint(low=0, high=num_embed, size=(indices_in_batch,), device=device, dtype=torch.long) + offsets = torch.from_numpy( + np.array([ + 0, *np.sort(np.random.randint(low=0, high=indices_in_batch, size=(indices_in_batch - 1,))), indices_in_batch + ])).to(device).long() + return indices, offsets + + +@pytest.mark.skip +def test_cachemgr(): + model = torch.nn.EmbeddingBag(10000, 128) + # 10 chunks, 5 in cuda + mgr = CachedParamMgr(model.weight.detach(), 5) + assert mgr.cuda_row_num == 5 + + mgr._admit(1) + assert not mgr._chunk_in_cuda(2) + assert mgr._chunk_in_cuda(1) + + # print(mgr.cached_chunk_table) + mgr._admit(8) + + # now 3 chunk is available + assert mgr.cuda_available_chunk_num == 3 + + mgr._evict() + assert mgr.cuda_available_chunk_num == 4 + + mgr._prepare_rows_on_cuda(torch.tensor([9, 6, 5], dtype=torch.long, device=0)) + mgr._prepare_rows_on_cuda(torch.tensor([3, 4, 5], dtype=torch.long, device=0)) + # print(mgr.cached_chunk_table) + # mgr.print_comm_stats() + + mgr.flush() + assert mgr.cuda_available_chunk_num == 5 + + +def test_reorder_with_freq(): + num_embed = 100 + chunk_size = 1 + num_chunk = 5 + + idx_map = torch.randint(10000, size=(num_embed,)) + sorted_idx = torch.argsort(idx_map, descending=True).tolist() + chunkid, offset_in_chunk = [], [] + for i in range(num_embed): + idx = sorted_idx.index(i) + chunkid.append(idx // chunk_size) + offset_in_chunk.append(idx % chunk_size) + + dev = torch.device('cuda') + chunkid = torch.tensor(chunkid, dtype=torch.long, device=dev) + offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=dev) + + weight = torch.rand(num_embed, 2) + mgr = CachedParamMgr(weight, num_chunk) + + mgr.reorder(idx_map) + + indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=dev)) + mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor') + mgr_offsets = torch.remainder(indices, chunk_size) + assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}" + assert torch.allclose(offset_in_chunk, mgr_offsets), \ + f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}" + + +@pytest.mark.parametrize('use_LFU', [True, False]) +def test_freq_aware_embed(use_LFU: bool): + device = torch.device('cuda', 0) + evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET + model = CachedEmbeddingBag(NUM_EMBED, + EMBED_DIM, + mode='mean', + include_last_offset=True, + cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0), + ids_freq_mapping=None, + evict_strategy=evict_strategy).to(device) + + assert model.weight.shape[0] == NUM_EMBED + ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device), + mode='mean', + include_last_offset=True, + freeze=False) + + assert torch.allclose(ref_model.weight.detach(), model.weight.detach().to(device)) + + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3) + + for i in range(5): + indices, offsets = synthesize_1d_sparse_feature(BATCH_SIZE, NUM_EMBED, device) + res = model(indices, offsets) + ref_res = ref_model(indices, offsets) + assert torch.allclose(res, ref_res), f"model result: {res}, reference: {ref_res}" + + grad = torch.rand_like(res) + # comparing gradient here is nontrivial + res.backward(grad) + ref_res.backward(grad) + optimizer.step() + optimizer.zero_grad() + + ref_optimizer.step() + ref_optimizer.zero_grad() + + model.cache_weight_mgr.flush() + model_weight = model.weight.detach().to(device) + ref_weight = ref_model.weight.detach() + assert torch.allclose(model_weight, ref_weight), \ + f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}" + + +@pytest.mark.parametrize('init_freq', [True, False]) +def test_lfu_strategy(init_freq: bool): + # minimal test to check behavior + Bag = CachedEmbeddingBag(5, + 5, + cache_ratio=3 / 5, + buffer_size=0, + pin_weight=True, + ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None, + warmup_ratio=1.0, + evict_strategy=EvictionStrategy.LFU) + + # print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map) + offsets = torch.tensor([0], device="cuda:0") + + # prepare frequency learning info: + Bag.forward(torch.tensor([2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) + + # check strategy + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([3], device="cuda:0"), offsets) # miss, evict 1 + Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit + Bag.forward(torch.tensor([4], device="cuda:0"), offsets) # miss, evict 3 + Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) # hit + + assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \ + "LFU strategy behavior failed" + + +def gather_tensor(tensor, rank, world_size): + gather_list = [] + if rank == 0: + gather_list = [torch.empty_like(tensor) for _ in range(world_size)] + + torch.distributed.gather(tensor, gather_list, dst=0) + return gather_list + + +def run_parallel_freq_aware_embed_tablewise(rank, world_size): + if world_size != 2: + return + device = torch.device('cuda', torch.cuda.current_device()) + + # initialize weight + # 3 feature tables. idx: 0~5, 6~10, 11~17 + weight_tables = torch.rand(18, 5) + weight_table1 = weight_tables[0:6] + weight_table2 = weight_tables[6:11] + weight_table3 = weight_tables[11:18] + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [] + embedding_bag_config_list.append( + TablewiseEmbeddingBagConfig(num_embeddings=6, + cuda_row_num=4, + assigned_rank=0, + initial_weight=weight_table1.clone().detach().cpu())) + embedding_bag_config_list.append( + TablewiseEmbeddingBagConfig(num_embeddings=5, + cuda_row_num=4, + assigned_rank=0, + initial_weight=weight_table2.clone().detach().cpu())) + embedding_bag_config_list.append( + TablewiseEmbeddingBagConfig(num_embeddings=7, + cuda_row_num=4, + assigned_rank=1, + initial_weight=weight_table3.clone().detach().cpu())) + if rank == 0: + _weight = torch.cat([weight_table1, weight_table2], 0) + else: + _weight = weight_table3 + model = ParallelCachedEmbeddingBagTablewise( + embedding_bag_config_list, + embedding_dim=5, + _weight=_weight, + include_last_offset=True, + cache_ratio=0.5, + buffer_size=0, + evict_strategy=EvictionStrategy.LFU, + ) + # explain + ''' + batch feature 1 feature 2 feature 3 + input0 [1,2,3] [6,7] [] + input1 [] [9] [13,15] + input2 [1,5] [6,8] [11] + โ†‘ โ†‘ โ†‘ + rank 0 rank 0 rank 1 + in KJT format + ''' + res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), + torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device), + already_split_along_rank=False) + optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) + rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device) + if rank == 0: + fake_grad = rand_grad[0:2] + else: + fake_grad = rand_grad[2:] + res.backward(fake_grad) + optimizer.step() + optimizer.zero_grad() + + # check correctness + if rank == 0: + ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_tables.detach().clone(), + include_last_offset=True, + freeze=False).to(device) + ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2) + ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0) + ref_res = ref_model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), + torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device)) + ref_res.backward(ref_fake_grad) + ref_optimizer.step() + ref_optimizer.zero_grad() + + model.cache_weight_mgr.flush() + recover_weight = model.cache_weight_mgr.weight.to(device) + ref_weight = ref_model.weight.detach()[:11] + assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}" + + +def run_parallel_freq_aware_embed_columnwise(rank, world_size): + device = torch.device('cuda', torch.cuda.current_device()) + + num_embed = 100 + embed_dim = 16 + batch_size = 4 + + set_seed(4321) + weight = torch.rand(num_embed, embed_dim) + coloweight = ColoTensor(weight.clone().detach().cpu(), spec=None) + + # initialize the tensor spec for the embedding weight parameter, + # which is an ColoParameter. + coloweight.set_process_group(ProcessGroup(tp_degree=world_size)) + coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D)) + + model = ParallelCachedEmbeddingBag.from_pretrained( + coloweight, + include_last_offset=True, + freeze=False, + cache_ratio=batch_size * 2 / num_embed, + ) + + assert model.cache_weight_mgr.weight.device.type == 'cpu' + assert model.cache_weight_mgr.cuda_cached_weight.requires_grad + weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank] + print(f"model weight: {model.cache_weight_mgr.weight.shape}, ref weight: {weight_in_rank.shape}") + assert torch.allclose(weight_in_rank, + model.cache_weight_mgr.weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.weight}" + + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + + if rank == 0: + ref_model = torch.nn.EmbeddingBag.from_pretrained(weight.detach().clone(), + include_last_offset=True, + freeze=False).to(device) + ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3) + + set_seed(4321) + for i in range(5): + indices, offsets = synthesize_1d_sparse_feature(batch_size, num_embed, device) + res = model(indices, offsets) + + grad = torch.rand(batch_size * 2, embed_dim, dtype=res.dtype, device=res.device) + grad_in_rank = torch.tensor_split(grad, world_size, 0)[rank] + res.backward(grad_in_rank) + + optimizer.step() + optimizer.zero_grad() + + res_list = gather_tensor(res.detach(), rank, world_size) + + if rank == 0: + ref_res = ref_model(indices, offsets) + recover_res = torch.cat(res_list, dim=0) + + assert torch.allclose(ref_res, recover_res) + + ref_res.backward(grad) + ref_optimizer.step() + ref_optimizer.zero_grad() + + model.cache_weight_mgr.flush() + weight_list = gather_tensor(model.cache_weight_mgr.weight.detach().cuda(), rank, world_size) + if rank == 0: + recover_weight = torch.cat(weight_list, dim=1) + assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}" + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + # run_parallel_freq_aware_embed_columnwise(rank, world_size) + run_parallel_freq_aware_embed_tablewise(rank, world_size) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_parallel_freq_aware_embed(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + # test_freq_aware_embed(True) + test_parallel_freq_aware_embed(2) + # test_lfu_strategy(False) diff --git a/tests/test_layers/test_sequence/checks_seq/__init__.py b/tests/test_layers/test_sequence/checks_seq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py new file mode 100644 index 0000000000000000000000000000000000000000..2b7b999d43731ae5b5cd3f7cb87eecc6c1585fc4 --- /dev/null +++ b/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py @@ -0,0 +1,21 @@ +import torch + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn import TransformerSelfAttentionRing +from colossalai.utils import get_current_device + + +def check_selfattention(): + WORLD_SIZE = gpc.get_world_size(ParallelMode.SEQUENCE) + SUB_SEQ_LENGTH = 8 + BATCH = 4 + HIDDEN_SIZE = 16 + + layer = TransformerSelfAttentionRing(16, 8, 8, 0.1) + layer = layer.to(get_current_device()) + + hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device()) + attention_mask = torch.randint(low=0, high=2, + size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(get_current_device()) + out = layer(hidden_states, attention_mask) diff --git a/tests/test_layers/test_sequence/test_sequence.py b/tests/test_layers/test_sequence/test_sequence.py new file mode 100644 index 0000000000000000000000000000000000000000..3862c4ccd439ecd0805c543b1e6a1809a20ad65a --- /dev/null +++ b/tests/test_layers/test_sequence/test_sequence.py @@ -0,0 +1,143 @@ +import colossalai +import colossalai.nn as col_nn +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import pytest + +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.testing import rerun_if_address_is_in_use +from functools import partial + +CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence'))) + + +def check_ring_qk(rank, world_size): + # params + batch_size = 4 + num_heads = 4 + seq_length = 32 + attention_head_size = 32 + sub_seq_length = seq_length // world_size + + # create master tensors + q = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda() + k = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda() + dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + + # create distributed tensors + sub_q = q.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() + sub_k = k.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() + + # set autograd attributes + q.requires_grad = True + k.requires_grad = True + q.retain_grad() + k.retain_grad() + sub_q.requires_grad = True + sub_k.requires_grad = True + sub_q.retain_grad() + sub_k.retain_grad() + + # compute master attention scores + a = torch.matmul(q, k.transpose(2, 1)) + + # compute distributed attention scores + ring_qk = colossalai.nn.layer.parallel_sequence.RingQK.apply + sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length) + + # check master and distributed attetion scores + sub_master_a = a[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2) + + # run master backward + a.retain_grad() + a.mean().backward() + + # run distributed backward + partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + torch.autograd.backward(sub_a, partial_master_a_grad) + + # check master and distributed grads + partial_master_q_grad = q.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \ + 'attention score cannot match' + + +def check_ring_av(rank, world_size): + # params + batch_size = 4 + num_heads = 4 + seq_length = 16 + attention_head_size = 32 + sub_seq_length = seq_length // world_size + + # create master tensors + a = torch.rand(batch_size * num_heads, seq_length, seq_length).cuda() + v = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda() + dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + + # create distributed tensors + sub_a = a.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() + sub_v = v.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() + + # set autograd attributes + a.requires_grad = True + v.requires_grad = True + a.retain_grad() + v.retain_grad() + sub_a.requires_grad = True + sub_v.requires_grad = True + sub_a.retain_grad() + sub_v.retain_grad() + + # compute master attention scores + out = torch.matmul(a, v) + + # compute distributed attention scores + ring_av = colossalai.nn.layer.parallel_sequence.RingAV.apply + sub_out = ring_av(sub_a, sub_v, batch_size, num_heads, attention_head_size, sub_seq_length) + + # print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}') + + # check master and distributed output + sub_master_out = out[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2) + + # # run master backward + out.retain_grad() + out.mean().backward() + + # # run distributed backward + partial_master_out_grad = out.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + torch.autograd.backward(sub_out, partial_master_out_grad) + + # # check master and distributed grads + partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \ + 'attention output cannot match' + + +def run_test(rank, world_size): + colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=29500) + + # check_ring_qk(rank, world_size) + check_ring_av(rank, world_size) + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_sequence(): + world_size = 4 + run_func = partial(run_test, world_size=world_size) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_sequence() diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..e7b9a55277c6fa5b3a94e0c3440e80263101255f --- /dev/null +++ b/tests/test_moe/test_grad_handler.py @@ -0,0 +1,74 @@ +from functools import partial +import pytest +import torch +import torch.nn as nn +import torch.multiprocessing as mp +import torch.distributed as dist +import colossalai +from colossalai.utils import free_port, get_current_device +from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer, Experts +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.utils.moe import sync_moe_model_param +from colossalai.engine.gradient_handler import MoeGradientHandler +from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use + +BATCH_SIZE = 4 +DIM = 16 +CONFIG = dict() + + +def run_test(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + expert_module = nn.Linear + expert_factor = dict(in_features=DIM, out_features=DIM, device=get_current_device()) + + MOE_CONTEXT.setup(42) # MOE initialization + noisy_func = UniformNoiseGenerator() + router = Top1Router(noisy_func=noisy_func) + num_experts_list = [1, 2, 4] + layer_list = [] + for num_experts in num_experts_list: + exp = Experts(expert_module, num_experts, **expert_factor) + moe_layer = MoeLayer(DIM, num_experts, router, exp) + layer_list.append(moe_layer) + + model = nn.ModuleList(layer_list) + model = model.to(get_current_device()) + sync_moe_model_param(model) + + dist_dict = MOE_CONTEXT.parallel_info_dict + assert_equal_in_group(layer_list[0].experts.experts[0].weight.data, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[1].experts.experts[0].weight.data, dist_dict[2].dp_group) + # MoE model synchronization passed + + grad_handler = MoeGradientHandler(model, 0) + + rank = dist.get_rank() + torch.cuda.manual_seed(78 + rank) + data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) + grad = torch.randn_like(data) + + MOE_CONTEXT.reset_loss() + for layer in layer_list: + data, _ = layer(data) + data.backward(grad) + grad_handler.handle_gradient() + + assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[0].experts.experts[0].bias.grad, dist_dict[1].dp_group) + + assert_equal_in_group(layer_list[1].experts.experts[0].weight.grad, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[1].experts.experts[0].bias.grad, dist_dict[2].dp_group) + # MoE grad handler test passed + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_grad_handler(): + world_size = 4 + run_func = partial(run_test, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_grad_handler() diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..62f9241642b90e058547f93c62aafcc523c0a72b --- /dev/null +++ b/tests/test_moe/test_kernel.py @@ -0,0 +1,105 @@ +from functools import partial +import pytest +import torch +import torch.nn as nn +import torch.multiprocessing as mp +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import free_port, get_current_device +from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.testing import rerun_if_address_is_in_use + +BATCH_SIZE = 16 +NUM_EXPERTS = 4 +CONFIG = dict() + + +def check_equal(tensor_a, tensor_b, atol=1e-06): + assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol) is True + + +def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, router=Top2Router): + # Here we do not need TF32, since it brings absolute error on results + torch.backends.cuda.matmul.allow_tf32 = False + + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + local_rank = gpc.get_local_rank(ParallelMode.GLOBAL) + + MOE_CONTEXT.setup(42) # MOE environment initialization + MOE_CONTEXT.reset_loss() + torch.manual_seed(rs + local_rank) # set each process has different random seed + + # get randomized data + tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) + + expert_module = nn.Linear + expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device()) + expert = Experts(expert_module, NUM_EXPERTS, **expert_factor) + layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert) + layer = layer.to(get_current_device()) + if data_type == torch.float16: + layer = layer.half() + + # use matrix multiplication instead of COL_MOE_KERNL in MOE dispatch and combine + layer.use_kernel = False + old_out, _ = layer(tokens) + ech = old_out.shape + grad = torch.randn(ech, device=get_current_device()) + old_out.backward(grad) # get gradient + + # save all results + o_tk_grad = tokens.grad.data.clone() + o_gt_grad = layer.gate_weight.grad.data.clone() + + # reset all gradients + tokens.grad.zero_() + layer.gate_weight.grad.zero_() + + layer.use_kernel = True + new_out, _ = layer(tokens) # get ouputs through colossal kernel + + if data_type == torch.float32: + check_equal(old_out, new_out) + else: + check_equal(old_out, new_out, 1e-2) + # forward function passed + + new_out.backward(grad) # get new type gradient + n_tk_grad = tokens.grad.data.clone() + n_gt_grad = layer.gate_weight.grad.data.clone() + + if data_type == torch.float32: + check_equal(o_tk_grad, n_tk_grad) + else: + check_equal(o_tk_grad, o_tk_grad, 1e-2) + # tokens gradient is correct + + if data_type == torch.float32: + check_equal(o_gt_grad, n_gt_grad, 5e-05) + else: + check_equal(o_gt_grad, n_gt_grad, 2e-01) + # bias gradient is correct + + +@pytest.mark.dist +@pytest.mark.parametrize("rs", [131]) +@pytest.mark.parametrize("hidden_size", [32, 144]) +@pytest.mark.parametrize("data_type", [torch.float32, torch.float16]) +@pytest.mark.parametrize("router", [Top1Router, Top2Router]) +@rerun_if_address_is_in_use() +def test_moe_kernel(rs, hidden_size, data_type, router): + world_size = 4 + run_func = partial(run_routing, + world_size=world_size, + port=free_port(), + rs=rs, + hidden_size=hidden_size, + data_type=data_type, + router=router) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_kernel(2, 256, torch.float16, Top2Router) diff --git a/tests/test_moe/test_moe_colo_init.py b/tests/test_moe/test_moe_colo_init.py new file mode 100644 index 0000000000000000000000000000000000000000..d54e2afdacdb7ceeeaa45c40403d23494585b302 --- /dev/null +++ b/tests/test_moe/test_moe_colo_init.py @@ -0,0 +1,63 @@ +from functools import partial + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +from colossalai.testing import parameterize +from colossalai.utils import free_port +from colossalai.context import MOE_CONTEXT +from colossalai.tensor import ColoParameter +from colossalai.utils.model.colo_init_context import ColoInitContext + +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import get_current_device + +from tests.test_zero.common import CONFIG +from tests.test_moe.test_moe_zero_init import MoeModel +from tests.test_tensor.common_utils import debug_print + + +@parameterize("init_device_type", ['cpu', 'cuda']) +def exam_moe_colo_init(init_device_type): + world_size = dist.get_world_size() + + if init_device_type == 'cuda': + init_device = get_current_device() + elif init_device_type == 'cpu': + init_device = torch.device("cpu") + else: + raise NotImplementedError("Unknown device found.") + + with ColoInitContext(device=init_device): + model = MoeModel(checkpoint=True) + + for name, param in model.named_parameters(): + assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name) + + if hasattr(param, "moe_info"): + param.set_process_group(param.moe_info.pg) + + if hasattr(param, "moe_info"): + assert param.process_group.dp_world_size() == param.moe_info.dp_size + else: + assert param.process_group.dp_world_size() == world_size + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + exam_moe_colo_init() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_moe_colo_init(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_colo_init(world_size=4) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py new file mode 100644 index 0000000000000000000000000000000000000000..3126f59e246e718dd81f2f6a7299431867fd43d1 --- /dev/null +++ b/tests/test_moe/test_moe_group.py @@ -0,0 +1,71 @@ +from functools import partial +import pytest +import torch.nn as nn +import torch.multiprocessing as mp +import torch.distributed as dist +import colossalai +from colossalai.utils import free_port, get_current_device +from colossalai.nn.layer.moe import Experts +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.utils.moe import sync_moe_model_param +from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use + +D_MODEL = 4 +D_FF = 8 +CONFIG = dict() + + +def run_test(rank, port): + world_size = 4 + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + expert_module = nn.Linear + expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device()) + + MOE_CONTEXT.setup(42) # MOE environment initialization + exp0 = Experts(expert_module, 1, **expert_factor) + exp1 = Experts(expert_module, 2, **expert_factor) + exp2 = Experts(expert_module, 4, **expert_factor) + exp3 = Experts(expert_module, 8, **expert_factor) + + assert exp0.num_local_experts == 1 + assert exp1.num_local_experts == 1 + assert exp2.num_local_experts == 1 + assert exp3.num_local_experts == 2 + # experts deployment passed + + parallel_info_dict = MOE_CONTEXT.parallel_info_dict + rank = dist.get_rank() + + assert len(parallel_info_dict) == 3 + assert dist.get_rank(parallel_info_dict[4].ep_group) == rank + assert dist.get_rank(parallel_info_dict[2].ep_group) == rank % 2 + assert dist.get_rank(parallel_info_dict[1].ep_group) == 0 + + assert dist.get_rank(parallel_info_dict[4].dp_group) == 0 + assert dist.get_rank(parallel_info_dict[2].dp_group) == rank // 2 + assert dist.get_rank(parallel_info_dict[1].dp_group) == rank + # group creation passed + + model = nn.ModuleList([exp0, exp1, exp2, exp3]) + model = model.to(get_current_device()) + sync_moe_model_param(model) + + assert_equal_in_group(exp0.experts[0].weight.data, parallel_info_dict[1].dp_group) + assert_equal_in_group(exp0.experts[0].bias.data, parallel_info_dict[1].dp_group) + # MOE experts layout success when ep_size = 1 + + assert_equal_in_group(exp1.experts[0].weight.data, parallel_info_dict[2].dp_group) + assert_equal_in_group(exp1.experts[0].bias.data, parallel_info_dict[2].dp_group) + # MOE experts layout success when ep_size = 2 + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_moe_initialization(): + world_size = 4 + run_func = partial(run_test, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_initialization() diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py new file mode 100644 index 0000000000000000000000000000000000000000..2ea6ac8170e4ec79d0c6ac9c33a68a0d0f6a64b2 --- /dev/null +++ b/tests/test_moe/test_moe_zero_init.py @@ -0,0 +1,114 @@ +from functools import partial + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from colossalai.nn import CheckpointModule +from colossalai.logging import get_dist_logger +from colossalai.testing import parameterize +from colossalai.utils import free_port +from colossalai.context import MOE_CONTEXT +from colossalai.nn.layer import MoeModule +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) + +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import get_current_device +from tests.test_zero.common import CONFIG + + +class MoeModel(nn.Module): + + def __init__(self, checkpoint: bool = False): + + class TestSubModule(CheckpointModule): + + def __init__(self): + super().__init__(checkpoint) + expert_cls = nn.Linear + expert_args_dict = dict(in_features=16, out_features=16) + self.moe = MoeModule(dim_model=16, + num_experts=8, + use_residual=True, + expert_cls=expert_cls, + **expert_args_dict) + self.proj = nn.Linear(16, 4) + + def _forward(self, x): + x, y = self.moe(x) + x = self.proj(x) + return x, y + + super().__init__() + self.test_embed = nn.Linear(4, 16) + self.test_transform = TestSubModule() + + def forward(self, x): + MOE_CONTEXT.reset_loss() + + x = self.test_embed(x) + x, y = self.test_transform(x) + + MOE_CONTEXT.add_loss(y) + return x + + +@parameterize("init_device_type", ['cpu', 'cuda']) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_moe_zero_init(init_device_type, shard_strategy_class): + logger = get_dist_logger("test_moe_zero_init") + + if init_device_type == 'cuda': + init_device = get_current_device() + elif init_device_type == 'cpu': + init_device = torch.device("cpu") + else: + raise NotImplementedError("Unknown device found.") + + model_numel_tensor = torch.zeros(1, dtype=torch.int) + with ZeroInitContext(target_device=init_device, + shard_strategy=shard_strategy_class(), + shard_param=True, + model_numel_tensor=model_numel_tensor): + model = MoeModel(checkpoint=True) + + for name, param in model.named_parameters(): + assert hasattr(param, 'colo_attr') + + # the parameters in moe experts and its gate should not be sharded + if ('experts' in name) or ('gate' in name) or ('residual_combine' in name): + assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name) + else: + assert param.colo_attr.sharded_data_tensor.is_sharded + + # the parameters in moe experts is not replicated + if 'experts' in name: + assert not param.colo_attr.is_replicated + else: + assert param.colo_attr.is_replicated + + if param.colo_attr.param_is_sharded: + assert param.colo_attr.data_payload.device.type == init_device.type, \ + f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}' + else: + assert param.colo_attr.data_payload.device.type == 'cuda' + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + run_moe_zero_init() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2, 4]) +@rerun_if_address_is_in_use() +def test_moe_zero_init(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_zero_init(world_size=2) diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d608ebf0718eb9615ae8ec142f6ea2b693d203ba --- /dev/null +++ b/tests/test_moe/test_moe_zero_model.py @@ -0,0 +1,75 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.context import MOE_CONTEXT +from colossalai.engine.gradient_handler import MoeGradientHandler +from colossalai.nn import MoeLoss +from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.sharded_model.utils import col_model_deepcopy +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_moe.test_moe_zero_init import MoeModel +from tests.test_zero.common import CONFIG, check_grads_padding, run_fwd_bwd + + +@parameterize("enable_autocast", [False]) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_model_test(enable_autocast, shard_strategy_class): + shard_strategy = shard_strategy_class() + + get_components_func = non_distributed_component_funcs.get_callable('hanging_param_model') + _, train_dataloader, _, optimizer_class, _ = get_components_func() + criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) + + with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = MoeModel(checkpoint=True) + zero_model = ShardedModelV2(zero_model, shard_strategy) + + # check whether parameters are identical in ddp + for name, p in zero_model.named_parameters(): + if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated: + assert_equal_in_group(p.colo_attr.data_payload) + + model = MoeModel(checkpoint=True).half() + col_model_deepcopy(zero_model, model) + model = model.cuda() + grad_handler = MoeGradientHandler(model) + + for i, (data, label) in enumerate(train_dataloader): + if i > 5: + break + + data, label = cast_tensor_to_fp16(data).cuda(), label.cuda() + run_fwd_bwd(model, data, label, criterion, enable_autocast) + run_fwd_bwd(zero_model, data, label, criterion, enable_autocast) + grad_handler.handle_gradient() + + check_grads_padding(model, zero_model, loose=True) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + run_model_test() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2]) +@rerun_if_address_is_in_use() +def test_moe_zero_model(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_zero_model(world_size=2) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py new file mode 100644 index 0000000000000000000000000000000000000000..9d9a7bd173900b538010eebed907bb97aaf2520d --- /dev/null +++ b/tests/test_moe/test_moe_zero_optim.py @@ -0,0 +1,124 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.context import MOE_CONTEXT +from colossalai.engine.gradient_handler import MoeGradientHandler +from colossalai.nn import MoeLoss +from colossalai.nn.optimizer import CPUAdam +from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.sharded_model.utils import col_model_deepcopy +from colossalai.zero.sharded_optim import ShardedOptimizerV2 +from colossalai.zero.sharded_optim._utils import has_inf_or_nan +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_moe.test_moe_zero_init import MoeModel +from tests.test_zero.common import CONFIG, check_sharded_model_params + + +def _run_step(model, optimizer, data, label, criterion, grad_handler): + model.train() + optimizer.zero_grad() + + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + + loss = loss.float() + if isinstance(model, ShardedModelV2): + optimizer.backward(loss) + else: + loss.backward() + + if grad_handler is not None: + grad_handler.handle_gradient() + + optimizer.step() + + +@parameterize("cpu_offload", [True]) +@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug +@parameterize("reuse_fp16_shard", [True, False]) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def _run_test_sharded_optim_v2(cpu_offload, + shard_strategy_class, + use_cpuadam, + reuse_fp16_shard, + gpu_margin_mem_ratio=0.0): + shard_strategy = shard_strategy_class() + if use_cpuadam and cpu_offload is False: + return + MOE_CONTEXT.reset_loss() + get_components_func = non_distributed_component_funcs.get_callable('hanging_param_model') + _, train_dataloader, _, optimizer_class, _ = get_components_func() + criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) + + with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = MoeModel(checkpoint=True) + + zero_model = ShardedModelV2(zero_model, + shard_strategy, + tensor_placement_policy='cpu' if cpu_offload else 'cuda', + reuse_fp16_shard=reuse_fp16_shard) + + # check whether parameters are identical in ddp + for name, p in zero_model.named_parameters(): + if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated: + assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device())) + + model = MoeModel(checkpoint=True).half() + col_model_deepcopy(zero_model, model) + model = model.cuda().float() + + if use_cpuadam: + optimizer_class = CPUAdam + optim = optimizer_class(model.parameters(), lr=1e-3) + sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) + sharded_optim = ShardedOptimizerV2(zero_model, + sharded_optim, + initial_scale=2**5, + gpu_margin_mem_ratio=gpu_margin_mem_ratio) + + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False) + apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) + apex_grad_handler = MoeGradientHandler(model) + + for i, (data, label) in enumerate(train_dataloader): + if i > 5: + break + data, label = data.cuda(), label.cuda() + _run_step(apex_model, apex_optimizer, data, label, criterion, apex_grad_handler) + _run_step(zero_model, sharded_optim, data, label, criterion, None) + check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam) + for param in model.parameters(): + assert not has_inf_or_nan(param) + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + _run_test_sharded_optim_v2() + + +# use_cpuadam = True can be used with cpu_offload = False +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2]) +@rerun_if_address_is_in_use() +def test_moe_zero_optim(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_zero_optim(world_size=4) diff --git a/tests/test_ops/test_addmm_tp.py b/tests/test_ops/test_addmm_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..5182868b5bbd17ab35192dfead34e14a2f8092ff --- /dev/null +++ b/tests/test_ops/test_addmm_tp.py @@ -0,0 +1,77 @@ +import colossalai +import torch +import pytest +import torch.nn as nn +import torch.multiprocessing as mp +from colossalai.tensor import ColoTensor, ProcessGroup +from colossalai.tensor import ColoTensorSpec +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from functools import partial +from tests.test_tensor.common_utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d + + +class Conv1D(nn.Module): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + Basically works like a linear layer but the weights are transposed. + Args: + nf (`int`): The number of output features. + nx (`int`): The number of input features. + """ + + def __init__(self, nf, nx): + super().__init__() + self.nf = nf + w = torch.empty(nx, nf) + nn.init.normal_(w, std=0.02) + self.weight = nn.Parameter(w) + self.bias = nn.Parameter(torch.ones(nf)) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(size_out) + return x + + +def run_with_spec(spec_init_func, split_bias): + model = Conv1D(4, 16).cuda() + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + + weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg)) + bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg)) + + spec_init_func(weight, pg) + if split_bias: + spec_init_func(bias, pg) + + x = torch.rand(2, 16).cuda() + out = model(x) + colo_out = torch.addmm(bias, x, weight) + colo_out = colo_out.to_replicate() + assert tensor_equal(out, colo_out) + grad = torch.rand_like(out) + out.backward(grad) + colo_out.backward(grad) + tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size()) + tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size()) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=False) + run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=True) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_addmm_1d(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_addmm_1d(4) diff --git a/tests/test_ops/test_embedding_bag_tp.py b/tests/test_ops/test_embedding_bag_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..c7a1604e5455de06cbb78e78ddb66c0706aba4d4 --- /dev/null +++ b/tests/test_ops/test_embedding_bag_tp.py @@ -0,0 +1,47 @@ +from torch.nn import functional as F +from functools import partial + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup +from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d + + +def run_with_spec(spec_init_func): + pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) + model = torch.nn.EmbeddingBag(10, 4).cuda() + weight = ColoParameter(model.weight.clone(), True, ColoTensorSpec(pg)) + + spec_init_func(weight, pg) + + inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda() + offsets = torch.tensor([0, 4]).cuda() + out = model(inputs, offsets=offsets) + colo_out = F.embedding_bag(inputs, weight, offsets=offsets) + assert tensor_equal(out, colo_out) + grad = torch.rand_like(out) + out.backward(grad) + colo_out.backward(grad) + assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size()) + + +def run_dist(rank, world_size, port): + config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_with_spec(split_param_col_tp1d) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_embedding_bag_1d(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_embedding_bag_1d(4) diff --git a/tests/test_ops/test_embedding_tp.py b/tests/test_ops/test_embedding_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..541dc5c0932455a5c5a35ecb3f7bf7a7eea7a17a --- /dev/null +++ b/tests/test_ops/test_embedding_tp.py @@ -0,0 +1,48 @@ +from torch.nn import functional as F +from functools import partial + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor +from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d + + +def run_with_spec(spec_init_func, pg: ProcessGroup): + model = torch.nn.Embedding(12, 32).cuda() + weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg)) + + spec_init_func(weight, pg) + + x = torch.tensor((0, 3, 6, 9)).cuda() + out = model(x) + colo_out = F.embedding(x, weight) + assert tensor_equal(out, colo_out) + grad = torch.rand_like(out) + out.backward(grad) + colo_out.backward(grad) + # compare grad inside a TP group + assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size()) + + +def run_dist(rank, world_size, port): + # config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(tp_degree=world_size) + run_with_spec(split_param_row_tp1d, pg) + run_with_spec(split_param_col_tp1d, pg) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_embedding_1d(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_embedding_1d(4) diff --git a/tests/test_ops/test_linear_tp.py b/tests/test_ops/test_linear_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..603e98564de8d764ca3288300cc7ac3705640148 --- /dev/null +++ b/tests/test_ops/test_linear_tp.py @@ -0,0 +1,52 @@ +from functools import partial + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor +from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d + + +def run_with_spec(spec_init_func, split_bias): + pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) + model = torch.nn.Linear(4, 8).cuda() + weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg)) + bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg)) + + spec_init_func(weight, pg) + if split_bias: + spec_init_func(bias, pg) + + x = torch.rand(2, 4).cuda() + out = model(x) + colo_out = F.linear(x, weight, bias) + colo_out = colo_out.to_replicate() + assert tensor_equal(out, colo_out) + grad = torch.rand_like(out) + out.backward(grad) + colo_out.backward(grad) + assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size()) + assert tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size()) + + +def run_dist(rank, world_size, port): + config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=False) + run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=True) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_linear_1d(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_linear_1d(4) diff --git a/tests/test_ops/test_loss_func.py b/tests/test_ops/test_loss_func.py new file mode 100644 index 0000000000000000000000000000000000000000..1a6f0e7ab651660a50256cdb2978f3f73fd73eb3 --- /dev/null +++ b/tests/test_ops/test_loss_func.py @@ -0,0 +1,52 @@ +import torch +import pytest +import colossalai +import torch.nn.functional as F +import torch.multiprocessing as mp +from functools import partial +from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec +from colossalai.utils import get_current_device +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern + + +def check_cross_entropy(): + input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True) + input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True) + with torch.no_grad(): + input_ct.copy_(input_t) + + target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device()) + + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg)) + input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()])) + input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D)) + + output = F.cross_entropy(input_t, target) + output_colo = F.cross_entropy(input_shard, target) + assert torch.allclose(output_colo, output) + + output.backward() + output_colo.backward() + + assert torch.allclose(input_t.grad, input_ct.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_cross_entropy() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_loss_func(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_loss_func(1) diff --git a/tests/test_ops/test_op.py b/tests/test_ops/test_op.py new file mode 100644 index 0000000000000000000000000000000000000000..8d3cf50ff2aa3c0081fcfc95f593ffcd8e5406a5 --- /dev/null +++ b/tests/test_ops/test_op.py @@ -0,0 +1,91 @@ +import torch +import pytest +import colossalai +import torch.nn.functional as F +import torch.multiprocessing as mp +from functools import partial +from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec, ShardSpec +from colossalai.utils import get_current_device +from torch.nn import Parameter +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port + + +def _run_layer_norm(): + ln_op = torch.nn.LayerNorm(2, 3, device=get_current_device()) + + input_t = torch.randn(3, 2, device=get_current_device()) + + pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) + input_t_colo = ColoTensor.from_torch_tensor(input_t.clone().detach(), ColoTensorSpec(pg)) + + # prepare colossalai LN + weight = ColoTensor(Parameter(ln_op.weight.detach()), ColoTensorSpec(pg)) + bias = ColoTensor(Parameter(ln_op.bias.detach()), ColoTensorSpec(pg)) + + output = ln_op(input_t) + output_colo = F.layer_norm(input_t_colo, ln_op.normalized_shape, weight, bias, ln_op.eps) + + assert torch.allclose(output_colo, output) + + torch.mean(output).backward() + torch.mean(output_colo).backward() + + assert torch.allclose(ln_op.weight.grad, weight.grad) + + +def check_spec_eq(tensor, other): + assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor) + for k in dir(tensor.dist_spec): + if not k.startswith('__'): + assert hasattr(other.dist_spec, k), f"{k}" + assert getattr(tensor.dist_spec, k) == getattr(other.dist_spec, k) + + +def check_element_wise_ops(): + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + t = torch.rand(2, 2) + x = ColoTensor(t, spec=ColoTensorSpec(pg, ShardSpec([0], [pg.tp_world_size()]))) + + check_spec_eq(x, x.cuda()) + assert torch.equal(x.cuda(), t.cuda()) + check_spec_eq(x, torch.abs(x)) + assert torch.equal(torch.abs(x), torch.abs(t)) + check_spec_eq(x, F.sigmoid(x)) + assert torch.equal(F.sigmoid(x), F.sigmoid(t)) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_element_wise_ops() + _run_layer_norm() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_element_wise_ops(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +def run_dist2(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + _run_layer_norm() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1]) +@rerun_if_address_is_in_use() +def test_ln(world_size): + run_func = partial(run_dist2, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +def check_all(): + test_element_wise_ops(2) + + +if __name__ == '__main__': + check_all() diff --git a/tests/test_ops/test_view.py b/tests/test_ops/test_view.py new file mode 100644 index 0000000000000000000000000000000000000000..c48919686d97cc2c1bc4621fb12b4e30d657c464 --- /dev/null +++ b/tests/test_ops/test_view.py @@ -0,0 +1,100 @@ +from functools import partial + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor, ShardSpec +from colossalai.tensor.distspec import DistPlacementPattern +from tests.test_tensor.common_utils import split_param_row_tp1d, split_param_col_tp1d, debug_print + + +def exam_view_core(pg): + # the case of replicated ColoTensors + x = torch.randn(4, 4).cuda() + x_colo = ColoTensor(x, ColoTensorSpec(pg)) + + y = x.view(2, -1, 2) + y_colo = x_colo.view(2, -1, 2) + + assert torch.all(y == y_colo) + assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE + # the perfect case of col-sliced ColoTensors + split_param_col_tp1d(x_colo, pg) + + z = x.view(torch.Size((2, 1, 2, -1))) + z_colo = x_colo.view(torch.Size((2, 1, 2, -1))) + if dist.get_rank() == 0: + z = z[:, :, :, 0:2] + else: + z = z[:, :, :, 2:] + assert torch.all(z == z_colo) + assert z_colo.dist_spec == x_colo.dist_spec + # the perfect case of row-sliced ColoTensors + split_param_row_tp1d(x_colo, pg) + + z = x.view(torch.Size((-1, 2, 2))) + z_colo = x_colo.view(torch.Size((-1, 2, 2))) + if dist.get_rank() == 0: + z = z[0:2, :, :] + else: + z = z[2:, :, :] + assert torch.all(z == z_colo) + assert z_colo.dist_spec == x_colo.dist_spec + # the normal case of row-sliced ColoTensors + z = x.view(-1, 2, 2, 2) + z_colo = x_colo.view(-1, 2, 2, 2) + assert torch.all(z == z_colo) + assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE + + +def exam_view_autograd(pg): + x = torch.randn(8, 2, device=get_current_device(), requires_grad=True) + y = torch.randn(8, 2, device=get_current_device(), requires_grad=True) + with torch.no_grad(): + y.copy_(x) + y = ColoTensor(y, ColoTensorSpec(pg)) + y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()])) + + xx = x.view(2, 2, -1) + yy_slice = y_slice.view(2, 2, -1) + yy = yy_slice.to_replicate() + grad = torch.randn(2, 2, 4, device=get_current_device()) + + xx.backward(grad) + yy.backward(grad) + assert torch.all(x.grad == y.grad) + + +def exam_view_errors(pg): + x = torch.randn(8, 2, device=get_current_device()) + x = ColoTensor(x, ColoTensorSpec(pg)) + split_param_row_tp1d(x, pg) + + x.view('a', 'b', 'c') + x.view(8, -1) + x.view([-2, -2, -2]) + x.view((-1, -1, -1)) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) + exam_view_core(pg) + exam_view_autograd(pg) + # exam_view_errors(pg) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_view(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_view(2) diff --git a/tests/test_optimizer/test_cpu_adam.py b/tests/test_optimizer/test_cpu_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..dff14fbcc5ad2f56612b63d6f69ae7e44951234b --- /dev/null +++ b/tests/test_optimizer/test_cpu_adam.py @@ -0,0 +1,117 @@ +import math + +import torch + +from colossalai.testing import parameterize + + +def torch_adam_update( + step, + lr, + beta1, + beta2, + eps, + weight_decay, + param, + grad, + exp_avg, + exp_avg_sq, + use_adamw, +): + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + if weight_decay != 0: + if use_adamw: + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + step_size = lr / bias_correction1 + + param.addcdiv_(exp_avg, denom, value=-step_size) + + +def assertLess(data_diff, threshold, msg): + assert data_diff < threshold, msg + + +def assertTrue(condition, msg): + assert condition, msg + + +@parameterize('adamw', [True, False]) +@parameterize('step', [1, 2]) +@parameterize('p_dtype', [torch.float, torch.half]) +@parameterize('g_dtype', [torch.float, torch.half]) +def test_cpu_adam(adamw, step, p_dtype, g_dtype): + lr = 1e-3 + beta1, beta2 = 0.9, 0.999 + eps = 1e-8 + weight_decay = 0 + + for i in range(1024): + p_data = torch.rand(64, dtype=p_dtype) + p_data_copy = p_data.clone().float() + p_grad = torch.rand(64, dtype=g_dtype) + p_grad_copy = p_grad.clone().float() + exp_avg = torch.rand(p_data.shape) + exp_avg_copy = exp_avg.clone() + exp_avg_sq = torch.rand(p_data.shape) + exp_avg_sq_copy = exp_avg_sq.clone() + + try: + import colossalai._C.cpu_optim + cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw) + except: + raise ImportError("Import cpu adam error, please install colossal from source code") + + cpu_adam_op.step( + step, + lr, + beta1, + beta2, + eps, + weight_decay, + True, + p_data.view(-1), # fp32 data + p_grad.view(-1), # fp32 grad + exp_avg.view(-1), + exp_avg_sq.view(-1), + -1, + ) + + torch_adam_update( + step, + lr, + beta1, + beta2, + eps, + weight_decay, + p_data_copy, # fp32 data + p_grad_copy, # fp32 grad + exp_avg_copy, + exp_avg_sq_copy, + adamw, + ) + var = p_data_copy - p_data + data_diff = torch.max(torch.abs(var)) + threshold = 1e-3 + assertLess( + data_diff, + threshold, + f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, eps " + f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}", + ) + max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad)) + assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}") + max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg)) + assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}") + max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq)) + assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}") diff --git a/tests/test_optimizer/test_fused_adam.py b/tests/test_optimizer/test_fused_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..f7227c2d57c0cbdd973a5cc5e48b74bec21f3608 --- /dev/null +++ b/tests/test_optimizer/test_fused_adam.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +from torch.optim.adam import Adam +from torch.optim import AdamW + +from colossalai.nn.optimizer.fused_adam import FusedAdam +from colossalai.testing import parameterize + + +class FC(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Sequential(nn.Linear(64, 64)) + + def forward(self, x): + return self.fc(x) + + +@parameterize('adamw', [False, True]) +@parameterize('p_dtype', [torch.float, torch.half]) +@parameterize('g_dtype', [torch.float, torch.half]) +def test_adam(adamw, p_dtype, g_dtype): + model = FC().cuda().to(p_dtype) + state = model.state_dict() + model_copy = FC().cuda().to(p_dtype) + model_copy.load_state_dict(state.copy()) + + if adamw: + optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True) + torch_optim = AdamW(model_copy.parameters(), lr=1e-3) + else: + optim = FusedAdam(model.parameters(), lr=1e-3) + torch_optim = Adam(model_copy.parameters(), lr=1e-3) + + data = torch.rand(1024, 64).cuda().to(p_dtype) + data_copy = data.clone() + label = torch.rand(1024, 64).cuda().to(p_dtype) + + for d, l in zip(data, label): + y = model(d) + loss = ((l - y)**2).sum() + optim.zero_grad() + loss.backward() + if p_dtype != g_dtype: + for i in range(len(optim.param_groups[0]['params'])): + optim.param_groups[0]['params'][i].grad.data = optim.param_groups[0]['params'][i].grad.data.to(g_dtype) + optim.step() + + for d, l in zip(data_copy, label): + y = model_copy(d) + loss = ((l - y)**2).sum() + torch_optim.zero_grad() + loss.backward() + torch_optim.step() + + assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params']) + + for i in range(len(optim.param_groups[0]['params'])): + if torch.isnan(optim.param_groups[0]['params'][i]).any() \ + or torch.isnan(torch_optim.param_groups[0]['params'][i]).any(): + continue + assert torch.allclose(optim.param_groups[0]['params'][i], torch_optim.param_groups[0]['params'][i], 2e-3, 2e-3) diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..d95a23702fe6eac84e05d31b892f07b480245f67 --- /dev/null +++ b/tests/test_optimizer/test_fused_adam_kernel.py @@ -0,0 +1,95 @@ +import math + +import torch +import torch.nn as nn +from numpy import dtype + +from colossalai.testing import parameterize +from colossalai.utils import multi_tensor_applier + + +def torch_adam_update( + step, + lr, + beta1, + beta2, + eps, + weight_decay, + param, + grad, + exp_avg, + exp_avg_sq, + use_adamw, +): + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + if weight_decay != 0: + if use_adamw: + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + step_size = lr / bias_correction1 + + param.addcdiv_(exp_avg, denom, value=-step_size) + + +@parameterize('adamw', [False, True]) +@parameterize('step', [1, 2]) +@parameterize('p_dtype', [torch.float, torch.half]) +@parameterize('g_dtype', [torch.float, torch.half]) +def test_adam(adamw, step, p_dtype, g_dtype): + try: + import colossalai._C.fused_optim + fused_adam = colossalai._C.fused_optim.multi_tensor_adam + dummy_overflow_buf = torch.cuda.IntTensor([0]) + except: + raise ImportError("No colossalai._C.fused_optim kernel installed.") + + count = 0 + + for i in range(1024): + p = torch.rand(64, dtype=p_dtype).cuda() + p_copy = p.clone().float() + g = torch.rand(p.shape, dtype=g_dtype).cuda() + g_copy = g.clone().float() + m = torch.rand(p.shape).cuda() + m_copy = m.clone() + v = torch.rand(p.shape).cuda() + v_copy = v.clone() + + lr = 1e-3 + beta1, beta2 = 0.9, 0.999 + eps = 1e-8 + weight_decay = 0 + + multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw, + True, weight_decay, -1) + + torch_adam_update( + step, + lr, + beta1, + beta2, + eps, + weight_decay, + p_copy, # fp32 data + g_copy, # fp32 grad + m_copy, + v_copy, + adamw, + ) + + if torch.isnan(p).any() or torch.isnan(p_copy).any(): + count += 1 + continue + assert count < 200, "too many nans" + assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5, + 1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}" diff --git a/tests/test_optimizer/test_hybrid_adam.py b/tests/test_optimizer/test_hybrid_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..d19192add3fb64e1b65a84e03375728e5ee5c2e7 --- /dev/null +++ b/tests/test_optimizer/test_hybrid_adam.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +from torch.optim.adam import Adam +from torch.optim import AdamW + +from colossalai.nn.optimizer.hybrid_adam import HybridAdam +from colossalai.testing import parameterize + +RE = 1024 + + +@parameterize('adamw', [False, True]) +@parameterize('device', ['cpu', 'cuda:0']) +@parameterize('p_dtype', [torch.float]) +@parameterize('g_dtype', [torch.float, torch.half]) +def test_adam(adamw, device, p_dtype, g_dtype): + rng_state = torch.get_rng_state() + p = nn.Parameter(torch.rand(64).to(device, p_dtype)) + torch.set_rng_state(rng_state) + p_copy = nn.Parameter(torch.rand(64).to(device).float()) + + if adamw: + optim = HybridAdam([p], lr=1e-3, adamw_mode=True) + torch_optim = AdamW([p_copy], lr=1e-3) + else: + optim = HybridAdam([p], lr=1e-3) + torch_optim = Adam([p_copy], lr=1e-3) + + print(f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}") + for i in range(RE): + p.grad = torch.rand(64).to(device, p_dtype) + p_copy.grad = p.grad.clone().float() + p.grad.data = p.grad.data.to(g_dtype) + + optim.step() + torch_optim.step() + + if torch.isnan(p.data).any() or torch.isnan(p_copy.data).any(): + continue + assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2), \ + f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}" diff --git a/tests/test_optimizer/test_nvme.py b/tests/test_optimizer/test_nvme.py new file mode 100644 index 0000000000000000000000000000000000000000..243f785adaf9b7ffbb1f130ada07a3fc2955c02a --- /dev/null +++ b/tests/test_optimizer/test_nvme.py @@ -0,0 +1,46 @@ +import pytest +import torch +from tests.components_to_test.registry import non_distributed_component_funcs +from colossalai.nn.optimizer import CPUAdam, HybridAdam + + +def move_some_params_to_cuda(model, torch_model): + model.embed.weight.data = model.embed.weight.cuda() + torch_model.embed.weight.data = model.embed.weight.cuda() + model.ln1.weight.data = model.ln1.weight.cuda() + torch_model.ln1.weight.data = model.ln1.weight.cuda() + + +def check_params_equal(model, torch_model): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + assert torch.allclose(p, torch_p, atol=1e-3), f'diff: {torch.abs(p - torch_p)}' + + +@pytest.mark.parametrize('nvme_offload_fraction', [0.0, 0.5, 1.0]) +@pytest.mark.parametrize('nvme_offload_dir', ['./offload', None]) +@pytest.mark.parametrize('adam_cls', [CPUAdam, HybridAdam]) +def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls): + get_components_func = non_distributed_component_funcs.get_callable('simple_net') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model = model_builder() + torch_model = model_builder() + move_some_params_to_cuda(model, torch_model) + optimizer = adam_cls(model.parameters(), + lr=0.1, + nvme_offload_fraction=nvme_offload_fraction, + nvme_offload_dir=nvme_offload_dir) + torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.1) + with torch.no_grad(): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + torch_p.copy_(p) + p.grad = torch.rand_like(p) + torch_p.grad = p.grad + + for _ in range(3): + optimizer.step() + torch_optimizer.step() + check_params_equal(model, torch_model) + + +if __name__ == '__main__': + test_nvme_adam(0.5, './offload', CPUAdam) diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7ce2cd433b1254746e0d1ff27bf3063c9b88d3a2 --- /dev/null +++ b/tests/test_pipeline/rpc_test_utils.py @@ -0,0 +1,144 @@ +import argparse +import os +import warnings + +import torch +import torch.distributed as dist +import torch.distributed.rpc as rpc +import torch.multiprocessing as mp +from colossalai import launch +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.pipeline_process_group import ppg +from torch import nn +from torch._C._distributed_rpc import _is_current_rpc_agent_set +from torch.optim import SGD, Adam, Optimizer, RMSprop + +rpc_is_initialized = _is_current_rpc_agent_set + + +def color_debug(text, prefix=' ', color='blue'): + color = color.upper() + print(getattr(Back, color), prefix, Style.RESET_ALL, text) + +class MLP(nn.Module): + def __init__(self, dim: int, layers: int): + super().__init__() + self.layers = torch.nn.ModuleList() + + for _ in range(layers): + self.layers.append(nn.Linear(dim, dim, bias=False)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x.sum() + +class DAG_MLP(nn.Module): + def __init__(self, dim: int, layers: int): + super().__init__() + self.layers = torch.nn.ModuleList() + self.dag_layer = nn.Linear(dim, dim, bias=False) + + for _ in range(layers): + self.layers.append(nn.Linear(dim, dim, bias=False)) + + def forward(self, x, y): + for layer in self.layers: + x = layer(x) + y = self.dag_layer(y) + return x.sum(), y.sum() + +class RpcTestModel(nn.Module): + + def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None: + super().__init__() + self.rank = stage_id + self.is_last_rank = stage_id == actual_stage_num - 1 + self.linear_name = f'linear_{stage_id}' + + if stage_id == 0: + linear = nn.Linear(feat_num, h) + elif stage_id == actual_stage_num - 1: + linear = nn.Linear(h, 1) + else: + linear = nn.Linear(h, h) + + setattr(self, self.linear_name, linear) + + def forward(self, x) -> torch.Tensor: + linear: nn.Module = getattr(self, self.linear_name) + out: torch.Tensor = linear(x) + + if self.is_last_rank: + out = out.sum() + return out + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--epoch', type=int, default=1) + parser.add_argument('--world_size', type=int, default=2) + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--dp_degree', type=int, default=1) + parser.add_argument('--tp_degree', type=int, default=1) + parser.add_argument('--num_microbatches', type=int, default=2) + parser.add_argument('--chunk', type=int, default=1) + parser.add_argument('--use_checkpoint', action='store_true') + parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD') + parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') + parser.add_argument('--master_addr', type=str, default='localhost') + parser.add_argument('--master_port', type=str, default='29020') + parser.add_argument('--num_worker_threads', type=str, default=128) + return parser.parse_args() + + +def pg_parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=4) + parser.add_argument('--dp_degree', type=int, default=2) + parser.add_argument('--tp_degree', type=int, default=1) + parser.add_argument('--chunk', type=int, default=1) + parser.add_argument('--num_worker_threads', type=str, default=128) + parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') + parser.add_argument('--master_addr', type=str, default='localhost') + parser.add_argument('--master_port', type=str, default='29020') + return parser.parse_args() + + +def run_worker(rank, args, master_func): + os.environ['MASTER_ADDR'] = args.master_addr + os.environ['MASTER_PORT'] = args.master_port + + device = args.device + world_size = args.world_size + dp_degree = args.dp_degree + tp_degree = args.tp_degree + num_worker_threads = args.num_worker_threads + host = args.master_addr + port = args.master_port + backend = 'nccl' if device == 'cuda' else 'gloo' + + disable_existing_loggers() + + launch(dict(), rank, world_size, host, int(port), backend, verbose=False) + ppg.set_global_info(rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device) + + # in rpc mode, only rank 0 is needed to be coded + if rank == 0: + master_func(args) + # barrier here + if rpc_is_initialized(): + rpc.shutdown() + else: + warnings.warn("RPC has not been initialized") + + +def rpc_run(args, master_func): + world_size = args.world_size + assert args.num_microbatches >= args.world_size, "num_microbatches cannot be fewer than world_size!" + mp.spawn(run_worker, args=(args, master_func), nprocs=world_size) diff --git a/tests/test_pipeline/test_cuda_rpc_chimera.py b/tests/test_pipeline/test_cuda_rpc_chimera.py new file mode 100644 index 0000000000000000000000000000000000000000..45ad8f828e61506649295b88ecd7e2fbfd5dcd3c --- /dev/null +++ b/tests/test_pipeline/test_cuda_rpc_chimera.py @@ -0,0 +1,80 @@ +import torch +from torch import nn +import torch.autograd as autograd + +from colossalai.pipeline.rpc import ChimeraPipelineEngine +from colossalai.testing import assert_close +from rpc_test_utils import rpc_run, parse_args, RpcTestModel + +# global variable for model created +feat_num = 100 +h = 100 + + +def partition(pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = RpcTestModel(pp_rank, stage_num, feat_num, h) + return partition + + +def run_master(args): + torch.manual_seed(100) + + epoch = args.epoch + device = args.device + stage_num = args.world_size + chunk = 1 + num_microbatches = args.num_microbatches + use_checkpoint = False + + sample_num = 1024 + batch_size = 1024 + + assert sample_num % batch_size == 0 + + engine = ChimeraPipelineEngine(partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + checkpoint=use_checkpoint) + engine.initialize_optimizer(torch.optim.Adam, lr=1e-3) + + input_sample = torch.randn((sample_num, feat_num), device=device) + + forward_result = engine.forward_backward(input_sample) + + cuda_rpc_result = [] + single_result = [] + actual_stage_num = engine._get_actual_stage_num() + + # compute forward result and backward grad of parameters in cuda rpc + cuda_rpc_result.append(sum(forward_result[0])) + grad = engine.remote_grad() + for stage_id in range(actual_stage_num): + for p in grad[stage_id]: + cuda_rpc_result.append(p) + + # compute forward result and backward grad of parameters just in rank_0 + test_model = nn.Sequential( + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) + # input_sample = input_sample[len(input_sample) // 2:] + input_sample = input_sample.requires_grad_() + out_val = test_model(input_sample).sum() + autograd.backward(out_val) + single_result.append(out_val) + for p in test_model.parameters(): + single_result.append(p.grad) + + # print("my") + # print(cuda_rpc_result[1]) + # print("answer:") + # print(single_result[1]) + + # assert len(cuda_rpc_result) == len(single_result) + # for r_c, r_s in zip(cuda_rpc_result, single_result): + # assert_close(r_c, r_s, 0.001, 0.001) + + +if __name__ == "__main__": + args = parse_args() + rpc_run(args, run_master) diff --git a/tests/test_pipeline/test_cuda_rpc_optimizer.py b/tests/test_pipeline/test_cuda_rpc_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..842566730caf2454bdeb6e27b864e58beb3ba7a1 --- /dev/null +++ b/tests/test_pipeline/test_cuda_rpc_optimizer.py @@ -0,0 +1,81 @@ +import torch +from torch import nn +from torch import autograd +from torch.optim import SGD, Adam, RMSprop, Optimizer + +from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from colossalai.testing import assert_close +from rpc_test_utils import rpc_run, parse_args, RpcTestModel + +# global variable for model created +feat_num = 100 +h = 100 + + +def partition(pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = RpcTestModel(pp_rank, stage_num, feat_num, h) + return partition + + +def run_master(args): + torch.manual_seed(100) + + device = args.device + stage_num = args.world_size + chunk = args.chunk + actual_stage_num = stage_num * chunk + use_checkpoint = args.use_checkpoint + num_microbatches = args.num_microbatches + optimizer_class = globals()[args.optimizer] + + lr = 1e-3 + sample_num = 1024 + batch_size = 1024 + + assert sample_num % batch_size == 0 + + input_sample = torch.randn((sample_num, feat_num), device=device) + + engine = OneFOneBPipelineEngine(partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint) + + engine.initialize_optimizer(optimizer_class, lr=lr) + + _ = engine.forward_backward(input_sample) + + cuda_rpc_result = [] + single_result = [] + actual_stage_num = engine._get_actual_stage_num() + + # compute parameters after updating in cuda rpc + parameters = engine.remote_parameters() + for stage_id in range(actual_stage_num): + for p in parameters[stage_id]: + cuda_rpc_result.append(p) + + # compute forward result and backward grad of parameters just in rank_0 + test_model = nn.Sequential( + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) + optimizer: Optimizer = optimizer_class(test_model.parameters(), lr=lr) + input_sample = input_sample.requires_grad_() + out_val = test_model(input_sample).sum() + autograd.backward(out_val) + optimizer.step() + optimizer.zero_grad() + + for p in test_model.parameters(): + single_result.append(p) + + assert len(cuda_rpc_result) == len(single_result) + for r_c, r_s in zip(cuda_rpc_result, single_result): + assert_close(r_c, r_s, 0.001, 0.001) + + +if __name__ == "__main__": + args = parse_args() + rpc_run(args, run_master) diff --git a/tests/test_pipeline/test_cuda_rpc_performance.py b/tests/test_pipeline/test_cuda_rpc_performance.py new file mode 100644 index 0000000000000000000000000000000000000000..6a0509555862ea0b1be5ba10c961075af854720a --- /dev/null +++ b/tests/test_pipeline/test_cuda_rpc_performance.py @@ -0,0 +1,90 @@ +import os +from typing import Callable, List, Optional, Type, Union +import time + +import pytest +import torch +import torch.nn as nn +from titans.dataloader.cifar10 import build_cifar +from torchvision.models import resnet50 +from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1 +from tqdm import tqdm + +from rpc_test_utils import rpc_run, parse_args +import colossalai +import colossalai.nn as col_nn +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.trainer import Trainer, hooks +from colossalai.utils import MultiTimer, get_dataloader +from colossalai.context import ParallelMode +from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel +from colossalai.pipeline.rpc import OneFOneBPipelineEngine, ChimeraPipelineEngine +from colossalai.pipeline.pipeline_process_group import ppg + + +def flatten(x): + return torch.flatten(x, 1) + + +def partition(pp_rank: int, chunk: int, stage_num: int): + pipelinable = PipelinableContext() + + # build model partitions + with pipelinable: + # input : [B, 3, 32, 32] + _ = resnet50() + + pipelinable.policy = "customized" + + exec_seq = [ + 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc' + ] + pipelinable.to_layer_list(exec_seq) + partition = pipelinable.partition(chunk, stage_num, pp_rank) + return partition + + +def run_master(args): + batch_size = args.batch_size + chunk = args.chunk + device = args.device + world_size = args.world_size + stage_num = world_size + num_microbatches = args.num_microbatches + + # build dataloader + root = os.environ.get('DATA', './data') + train_dataloader, test_dataloader = build_cifar(batch_size, root, padding=4, crop=32, resize=32) + criterion = nn.CrossEntropyLoss() + + pp_engine = OneFOneBPipelineEngine(partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + criterion=criterion, + checkpoint=False) + + pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3) + s = time.time() + + for bx, by in tqdm(train_dataloader): + pp_engine.forward_backward(bx, labels=by, forward_only=False) + + cost_time = time.time() - s + + print("total cost time :", cost_time) + print("cost time per batch:", cost_time / len(train_dataloader)) + + +@pytest.mark.skip("Test for performance, no need for CI") +def main(): + args = parse_args() + # this is due to limitation of partition function + args.world_size = 2 + args.chunk = 1 + rpc_run(args, run_master) + + +if __name__ == '__main__': + main() diff --git a/tests/test_pipeline/test_cuda_rpc_pipeline.py b/tests/test_pipeline/test_cuda_rpc_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..8d03e79813e89976153def8ecc7873af2c913701 --- /dev/null +++ b/tests/test_pipeline/test_cuda_rpc_pipeline.py @@ -0,0 +1,48 @@ +import torch +from torch import nn + +from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from rpc_test_utils import rpc_run, parse_args, RpcTestModel + +# global variable for model created +feat_num = 100 +h = 100 + + +def partition(pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = RpcTestModel(pp_rank, stage_num, feat_num, h) + return partition + + +def run_master(args): + torch.manual_seed(100) + + epoch = args.epoch + device = args.device + stage_num = args.world_size + chunk = args.chunk + num_microbatches = args.num_microbatches + use_checkpoint = args.use_checkpoint + + sample_num = 1024 + batch_size = 1024 + + assert sample_num % batch_size == 0 + + input_sample = torch.randn((sample_num, feat_num), device=device) + + engine = OneFOneBPipelineEngine(partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint) + + for _ in range(epoch): + _ = engine.forward_backward(input_sample, forward_only=False) + + +if __name__ == "__main__": + args = parse_args() + rpc_run(args, run_master) diff --git a/tests/test_pipeline/test_cuda_rpc_value_correctness.py b/tests/test_pipeline/test_cuda_rpc_value_correctness.py new file mode 100644 index 0000000000000000000000000000000000000000..e6713478baecae9115ac3142d5c601d850bec070 --- /dev/null +++ b/tests/test_pipeline/test_cuda_rpc_value_correctness.py @@ -0,0 +1,73 @@ +import torch +from torch import nn +from torch import autograd + +from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from colossalai.testing import assert_close +from rpc_test_utils import rpc_run, parse_args, RpcTestModel + +feat_num = 100 +h = 100 + + +def partition(pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = RpcTestModel(pp_rank, stage_num, feat_num, h) + return partition + + +def run_master(args): + torch.manual_seed(100) + + device = args.device + stage_num = args.world_size + chunk = args.chunk + actual_stage_num = stage_num * chunk + use_checkpoint = args.use_checkpoint + num_microbatches = args.num_microbatches + + sample_num = 1024 + batch_size = 1024 + + assert sample_num % batch_size == 0 + + input_sample = torch.randn((sample_num, feat_num), device=device) + + engine = OneFOneBPipelineEngine(partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint) + + forward_result = engine.forward_backward(input_sample) + + cuda_rpc_result = [] + single_result = [] + actual_stage_num = engine._get_actual_stage_num() + + # compute forward result and backward grad of parameters in cuda rpc + cuda_rpc_result.append(sum(forward_result[0])) + grad = engine.remote_grad() + for stage_id in range(actual_stage_num): + for p in grad[stage_id]: + cuda_rpc_result.append(p) + + # compute forward result and backward grad of parameters just in rank_0 + test_model = nn.Sequential( + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) + input_sample = input_sample.requires_grad_() + out_val = test_model(input_sample).sum() + autograd.backward(out_val) + single_result.append(out_val) + for p in test_model.parameters(): + single_result.append(p.grad) + + assert len(cuda_rpc_result) == len(single_result) + for r_c, r_s in zip(cuda_rpc_result, single_result): + assert_close(r_c, r_s, 0.001, 0.001) + + +if __name__ == "__main__": + args = parse_args() + rpc_run(args, run_master) diff --git a/tests/test_pipeline/test_middleware_1f1b.py b/tests/test_pipeline/test_middleware_1f1b.py new file mode 100644 index 0000000000000000000000000000000000000000..c4dc617b1683f480f1418ae9f3b974819463e3c1 --- /dev/null +++ b/tests/test_pipeline/test_middleware_1f1b.py @@ -0,0 +1,128 @@ +import torch +import pytest +import os +import torch.multiprocessing as mp +import torch.distributed.rpc as rpc + +from torch import nn +from torch._C._distributed_rpc import _is_current_rpc_agent_set +from colossalai import launch +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass +from colossalai.fx import ColoTracer +from colossalai.pipeline.middleware.adaptor import get_fx_topology +from rpc_test_utils import MLP, DAG_MLP +from functools import partial +from colossalai.testing import parameterize, rerun_if_address_is_in_use + +# global variable for model created +batch_size = 16 +dim = 10 +rpc_is_initialized = _is_current_rpc_agent_set + +def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): + model.eval() + tracer = ColoTracer() + meta_args = {k: v.to('meta') for k, v in data_kwargs.items()} + graph = tracer.trace(root=model, meta_args=meta_args) + gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) + annotated_model = balanced_split_pass(gm, stage_num) + top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True) + topo = get_fx_topology(top_module) + for submodule in split_submodules: + if isinstance(submodule, torch.fx.GraphModule): + setattr(submodule, '_topo', topo) + return split_submodules[pp_rank+1] + +def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = create_partition_module(pp_rank, stage_num, model, data_kwargs) + return partition + +def run_master(model_cls, world_size, forward_only): + torch.manual_seed(100) + + epoch = 3 + device = 'cuda' + stage_num = world_size + chunk = 1 + num_microbatches = 8 + use_checkpoint = 'store_true' + + if model_cls == MLP: + def data_gen(): + x = torch.zeros((batch_size, dim)) + kwargs = dict(x=x) + return kwargs + model = model_cls(dim, stage_num * 3) + if forward_only: + labels = None + else: + labels = 1 + elif model_cls == DAG_MLP: + def data_gen(): + x = torch.zeros((batch_size, dim)) + y = torch.zeros((batch_size, dim)) + kwargs = dict(x=x, y=y) + return kwargs + model = model_cls(dim, stage_num * 3) + if forward_only: + labels = None + else: + labels = 1 + else: + pass + + data_kwargs = data_gen() + + engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs), + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint,) + if not forward_only: + engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3) + + for _ in range(epoch): + input_x = torch.randn((batch_size, dim), device=device) + input_y = torch.randn((batch_size, dim), device=device) + logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only) + +def run_worker(rank, model_cls, world_size, forward_only, master_func): + master_addr = 'localhost' + master_port = 29020 + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = str(master_port) + + disable_existing_loggers() + + launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False) + ppg.set_global_info(rank=rank, + world_size=world_size, + dp_degree=1, + tp_degree=1, + num_worker_threads=128, + device='cuda') + + # in rpc mode, only rank 0 is needed to be coded + if rank == 0: + master_func(model_cls, world_size, forward_only) + # barrier here + if rpc_is_initialized(): + rpc.shutdown() + +@pytest.mark.skip("skip due to CI torch version 1.11") +@parameterize('model_cls', [MLP, DAG_MLP]) +@parameterize('forward_only', [True, False]) +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pp_middleware_fwd(model_cls, forward_only): + world_size = 4 + master_func = run_master + mp.spawn(run_worker, args=(model_cls, world_size, forward_only, master_func), nprocs=world_size) + +if __name__ == "__main__": + test_pp_middleware_fwd() \ No newline at end of file diff --git a/tests/test_pipeline/test_pipelinable.py b/tests/test_pipeline/test_pipelinable.py new file mode 100644 index 0000000000000000000000000000000000000000..c99a88550b71d4b53a7ef06a361a0e339e180052 --- /dev/null +++ b/tests/test_pipeline/test_pipelinable.py @@ -0,0 +1,59 @@ +import torch +import torch.multiprocessing as mp + +from colossalai.pipeline.pipelinable import PipelinableContext + +from colossalai.testing import rerun_on_exception + +NUM_CHUNKS = 1 +PIPELINE_SIZE = 2 + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int = 256): + super().__init__() + intermediate_dim = dim * 4 + self.dense_1 = torch.nn.Linear(dim, intermediate_dim) + self.activation = torch.nn.GELU() + self.dense_2 = torch.nn.Linear(intermediate_dim, dim) + self.dropout = torch.nn.Dropout(0.1) + + def forward(self, x): + x = self.dense_1(x) + x = self.activation(x) + x = self.dense_2(x) + x = self.dropout(x) + return x + + +def run_pipelinable(rank): + pipelinable = PipelinableContext() + with pipelinable: + model = MLP() + + assert pipelinable.policy == "balanced" + pipelinable.policy = "uniform" + assert pipelinable.policy == "uniform" + pipelinable.to_layer_list() + + assert pipelinable.layers_count == len(list(model.children())) + + pipeline_model_part_0 = pipelinable.partition(NUM_CHUNKS, PIPELINE_SIZE, 0) + assert isinstance(pipeline_model_part_0, torch.nn.Module) + pipeline_model_part_1 = pipelinable.partition(NUM_CHUNKS, PIPELINE_SIZE, 1) + assert isinstance(pipeline_model_part_1, torch.nn.Module) + + layers_count_in_part_0 = len(list(pipeline_model_part_0._module_list)) + layers_count_in_part_1 = len(list(pipeline_model_part_1._module_list)) + + assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count + + +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +def test_pipelinable(): + mp.spawn(run_pipelinable, nprocs=1) + + +if __name__ == '__main__': + test_pipelinable() diff --git a/tests/test_pipeline/test_pipeline_process_group.py b/tests/test_pipeline/test_pipeline_process_group.py new file mode 100644 index 0000000000000000000000000000000000000000..c67e4175df92b5b46b99efd7679433c7391fa3ba --- /dev/null +++ b/tests/test_pipeline/test_pipeline_process_group.py @@ -0,0 +1,43 @@ +import os + +import torch.distributed.rpc as rpc +import torch.multiprocessing as mp +import pytest + +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from rpc_test_utils import pg_parse_args, rpc_is_initialized + + +def run_worker(rank, args): + os.environ['MASTER_ADDR'] = args.master_addr + os.environ['MASTER_PORT'] = args.master_port + + device = args.device + world_size = args.world_size + dp_degree = args.dp_degree + tp_degree = args.tp_degree + num_worker_threads = args.num_worker_threads + host = args.master_addr + port = args.master_port + backend = 'nccl' if device == 'cuda' else 'gloo' + + disable_existing_loggers() + launch(dict(), rank, world_size, host, int(port), backend, verbose=False) + + ppg.set_global_info(rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device) + + if rpc_is_initialized(): + rpc.shutdown() + + +if __name__ == "__main__": + args = pg_parse_args() + world_size = args.world_size + mp.spawn(run_worker, args=(args,), nprocs=world_size) \ No newline at end of file diff --git a/tests/test_tensor/common_utils/__init__.py b/tests/test_tensor/common_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a35d02ce5edd0e2f4a16831621041009928f129 --- /dev/null +++ b/tests/test_tensor/common_utils/__init__.py @@ -0,0 +1 @@ +from ._utils import * diff --git a/tests/test_tensor/common_utils/_utils.py b/tests/test_tensor/common_utils/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5c5d06622bebf5eec39df136f22977a4bc5ec472 --- /dev/null +++ b/tests/test_tensor/common_utils/_utils.py @@ -0,0 +1,81 @@ +import os +import random +import numpy as np +import torch +import torch.distributed as dist +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern + + +def set_seed(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True + + +def replace_parameter_add_grad(layer, weight=None, bias=None): + if weight is not None: + delattr(layer, 'weight') + setattr(layer, 'weight', weight) + layer.weight.requires_grad = True + if bias is not None: + delattr(layer, 'bias') + setattr(layer, 'bias', bias) + layer.bias.requires_grad = True + + +def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0): + dist.broadcast(tensor, src=0) + tensor_chunk = torch.chunk(tensor, chunk_size, dim=-1)[local_rank] + return tensor_chunk.clone() + + +def tensor_equal(A, B): + return torch.allclose(A, B, rtol=1e-3, atol=1e-1) + + +def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size): + assert tensor.ndim == shard.ndim + if tensor.shape == shard.shape: + return tensor_equal(tensor, shard) + else: + dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape)) + if dims_not_eq.numel() == 1: + # 1D shard + dim = dims_not_eq.item() + if world_size is None: + world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + if rank is None: + rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + return tensor_equal(tensor.chunk(world_size, dim)[rank], shard) + else: + raise NotImplementedError + + +def split_param_single_dim_tp1d(dim, param, pg): + spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + if param.process_group.tp_world_size() == 1: + param.set_process_group(pg) + param.set_tensor_spec(*spec) + + +def split_param_row_tp1d(param, pg): + split_param_single_dim_tp1d(0, param, pg) + + +def split_param_col_tp1d(param, pg): + split_param_single_dim_tp1d(-1, param, pg) + + +def debug_print(ranks, *args): + if dist.get_rank() in ranks: + print(*args) + dist.barrier() diff --git a/tests/test_tensor/core/test_dist_spec_mgr.py b/tests/test_tensor/core/test_dist_spec_mgr.py new file mode 100644 index 0000000000000000000000000000000000000000..e02f4e7977f63fd35e142eef0e8b92632391e516 --- /dev/null +++ b/tests/test_tensor/core/test_dist_spec_mgr.py @@ -0,0 +1,66 @@ +import math +import torch +import torch.distributed as dist +import pytest +import colossalai +import torch.multiprocessing as mp +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.tensor import DistSpecManager, ProcessGroup, ShardSpec, ReplicaSpec +from functools import partial + + +def run(): + group = ProcessGroup(tp_degree=dist.get_world_size()) + rank = dist.get_rank() + size = dist.get_world_size() + depth = int(math.sqrt(size)) + assert depth == math.sqrt(size) + x = torch.rand(8, 8).cuda() + old_dist_spec = ReplicaSpec() + row_spec = ShardSpec([0], [size]) + col_spec = ShardSpec([-1], [size]) + mat_spec = ShardSpec([0, 1], [depth, depth]) + row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec, group) + assert torch.equal(x.chunk(size, 0)[rank], row_shard) + assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec, group)) + col_shard = DistSpecManager._all_to_all(row_shard, row_spec, col_spec, group) + assert torch.equal(x.chunk(size, -1)[rank], col_shard) + assert torch.equal(x, DistSpecManager._gather(col_shard, col_spec, group)) + mat_shard = DistSpecManager._shard_as(x, old_dist_spec, mat_spec, group) + assert torch.equal(x.chunk(depth, 0)[rank // depth].chunk(depth, 1)[rank % depth], mat_shard) + assert torch.equal(x, DistSpecManager._gather(mat_shard, mat_spec, group)) + + +def check_mem(): + pg = ProcessGroup(tp_degree=dist.get_world_size()) + size = dist.get_world_size() + assert torch.cuda.memory_allocated() == 0 + x = torch.rand(32, 32).cuda() + orig_mem = x.numel() * x.element_size() + assert torch.cuda.memory_allocated() == orig_mem + old_dist_spec = ReplicaSpec() + row_spec = ShardSpec([0], [size]) + x.data = DistSpecManager._shard_as(x, old_dist_spec, row_spec, pg) + assert x.size(0) == 32 // size and x.size(1) == 32 + assert torch.cuda.memory_allocated() == orig_mem // size + x.data = DistSpecManager._gather(x, row_spec, pg) + assert torch.cuda.memory_allocated() == orig_mem + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_mem() + run() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_dist_spec_mgr(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_dist_spec_mgr(4) diff --git a/tests/test_tensor/core/test_tensor.py b/tests/test_tensor/core/test_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..b48d9e9a2dfa4c937bcb2edcf7eba25089390ff9 --- /dev/null +++ b/tests/test_tensor/core/test_tensor.py @@ -0,0 +1,160 @@ +import torch +import pytest +from colossalai.tensor import ColoTensor +from numpy import allclose + +import colossalai +from colossalai.utils import free_port +from colossalai.tensor import ColoTensorSpec +from colossalai.core import global_context as gpc +import torch.multiprocessing as mp +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.tensor import distspec, ColoTensor, ProcessGroup, ShardSpec, ReplicaSpec +from functools import partial + + +def _run_tensor_indexing(): + pg = ProcessGroup() + torch_t = torch.randn(2, 3) + colo_t = ColoTensor(torch_t, ColoTensorSpec(pg)) + assert allclose(torch_t[:, 1], colo_t[:, 1]) + + +def _run_wrapped_tensor_func(): + pg = ProcessGroup() + t_ref = torch.randn(4, 5) + t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg)) + + # non-func attr + assert t.is_cuda == t_ref.is_cuda + + # return 1 torch.Tensor + t_abs = t.abs() + assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs, t_ref.abs()) + + # return 1 non-torch.Tensor + assert t.dim() == t_ref.dim() + + # return >1 torch.Tensor + assert isinstance(t, ColoTensor) + t_split1, t_split2 = t.split(2) + assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor), f"{type(t_split1)} {type(t_split2)}" + + +def _run_operand(world_size): + pg = ProcessGroup() + t_ref = torch.randn(4, 5) + t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg)) + + t_ref_res = t_ref + t_ref + t_res = t + t + + assert isinstance(t_res, ColoTensor) + assert torch.allclose(t_ref_res, t_res) + + pg = ProcessGroup(tp_degree=world_size) + t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg)) + t.set_dist_spec(ShardSpec([0], [world_size])) + t_new = torch.zeros_like(t) + assert isinstance(t_new, ColoTensor) + assert t_new.is_sharded() + + +#### Test Distributed init a Colotensor + + +def _run_view(world_size): + t_ref = torch.randn(4, 5) + rank = gpc.get_global_rank() + pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size) + t = ColoTensor.from_torch_tensor( + t_ref, ColoTensorSpec(pg, dist_attr=ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()]))) + + assert t.size_global()[0] == 4 * world_size + assert t.size_global(1) == 5 + assert t.size_global() == torch.Size([4 * world_size, 5]) + + t = t.view(4 * 5 * world_size) + assert t.shape == torch.Size([4 * 5 * world_size]) + + +def _run_tensor_shard_init(world_size): + t_ref = torch.randn(4, 5) + pg = ProcessGroup(tp_degree=world_size) + shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()]) + tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr) + t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec) + t.set_dist_spec(ReplicaSpec()) + + assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})" + + +def _run_tensor_replicated_init(world_size): + t_ref = torch.randn(4 * world_size, 5) + pg = ProcessGroup() + spec = ColoTensorSpec(pg) + t = ColoTensor.from_torch_tensor(t_ref.clone(), spec) + + assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}" + + +def _run_process_group(world_size): + pg1 = ProcessGroup() + pg2 = ProcessGroup() + assert pg1 == pg2 + + +def _run_redistributed(world_size): + if world_size != 4: + return + pg1 = ProcessGroup(tp_degree=2, dp_degree=2) + pg2 = ProcessGroup(tp_degree=4, dp_degree=1) + + spec1 = ColoTensorSpec(pg1) + t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1) + t1 = t1.redistribute(ShardSpec([0], [pg1.tp_world_size()])) + assert t1.is_sharded() + t1 = t1.redistribute(ShardSpec([-1], [pg2.tp_world_size()]), pg2) + assert t1.is_sharded() + pg3 = ProcessGroup(tp_degree=1, dp_degree=4) + t1 = t1.redistribute(ReplicaSpec(), pg3) + assert t1.is_replicate() + + +def _run_set_tensor_spec(world_size): + if world_size != 4: + return + pg = ProcessGroup(tp_degree=2, dp_degree=2) + spec1 = ColoTensorSpec(pg) + t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1) + + dist_spec2 = ShardSpec([-1], [pg.tp_world_size()]) + assert t1.is_replicate() + t1.set_dist_spec(dist_spec2) + assert t1.is_shard_1dcol() + + +def run_dist_tests(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + _run_tensor_shard_init(world_size) + _run_tensor_replicated_init(world_size) + _run_view(world_size) + _run_process_group(world_size) + _run_tensor_indexing() + _run_operand(world_size) + _run_wrapped_tensor_func() + _run_redistributed(world_size) + _run_set_tensor_spec(world_size) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_dist_cases(world_size): + run_func = partial(run_dist_tests, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_dist_cases(4) diff --git a/tests/test_tensor/model/test_gpt2.py b/tests/test_tensor/model/test_gpt2.py new file mode 100644 index 0000000000000000000000000000000000000000..ad8ac87b2e1ede2b1ee8c3be4f1964729fccc152 --- /dev/null +++ b/tests/test_tensor/model/test_gpt2.py @@ -0,0 +1,153 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.nn.parallel.data_parallel import ColoDDP +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import ( + debug_print, + set_seed, + split_param_col_tp1d, + split_param_row_tp1d, + tensor_equal, + tensor_shard_equal, +) + + +def init_1d_row_spec(model, pg: ProcessGroup): + tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + for n, p in model.named_parameters(): + p.set_process_group(pg) + if 'weight' in n and 'ln' not in n: + p.set_tensor_spec(*tensor_spec) + + +def init_1d_col_spec(model, pg: ProcessGroup): + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + + for n, p in model.named_parameters(): + p.set_process_group(pg) + if 'ln' not in n and ('weight' in n or 'bias' in n): + p.set_tensor_spec(*spec) + + +def init_megatron_spec(model, pg: ProcessGroup): + for mn, module in model.named_modules(): + # debug_print([0], mn) + for pn, param in module.named_parameters(recurse=False): + # debug_print([0], '\t', pn, param.compute_spec, param.shape) + param.set_process_group(pg) + + if 'mlp.c_fc' in mn: + if 'weight' in pn or 'bias' in pn: + split_param_col_tp1d(param, pg) + param.compute_spec.set_output_replicate(False) + else: + raise RuntimeError + elif 'mlp.c_proj' in mn: + if 'weight' in pn: + split_param_row_tp1d(param, pg) + else: + assert 'bias' in pn + elif 'wte' in mn or 'wpe' in mn: + assert 'weight' in pn + split_param_col_tp1d(param, pg) + elif 'c_attn' in mn or 'c_proj' in mn: + split_param_col_tp1d(param, pg) + # debug_print([0], '\t', param.compute_spec, param.shape) + + +def check_param_equal(model, torch_model, pg: ProcessGroup): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()}1" + assert pg.tp_world_size() is not None + assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size()) + + +def check_grad_equal(model, torch_model, pg: ProcessGroup): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + assert tensor_shard_equal(torch_p.grad, p.grad, pg.tp_local_rank(), pg.tp_world_size()) + + +def run_gpt(init_spec_func, use_ddp): + world_size = torch.distributed.get_world_size() + + # build a PG with TP and DP hybrid + pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1)) + + # set seed make processes of the same tp group use the same seed + # set_seed(pg.tp_local_rank()) + + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + # make sure torch_model and model has the same parameter values + with ColoInitContext(device=get_current_device()): + model = model_builder() + model = model.cuda() + torch_model = model_builder().cuda() + + if use_ddp: + torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) + model = ColoDDP(model, process_group=pg) + + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p) + + init_spec_func(model, pg) + + check_param_equal(model, torch_model, pg) + + # close the dropout in eval mode + model.eval() + torch_model.eval() + set_seed(pg.dp_local_rank()) + torch.distributed.barrier() + for i, (input_ids, label) in enumerate(train_dataloader): + colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg)) + logits = model(colo_input) + torch_logits = torch_model(input_ids) + assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}" + loss = criterion(logits, input_ids) + torch_loss = criterion(torch_logits, input_ids) + if use_ddp: + model.backward(loss) + else: + loss.backward() + torch_loss.backward() + check_grad_equal(model, torch_model, pg) + if i > 0: + break + set_seed(313) + + +def run_dist(rank, world_size, port, use_ddp): + if use_ddp and world_size == 1: + return + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + # Comments below tests for speed concern + # run_gpt(init_1d_row_spec, use_ddp) + # run_gpt(init_1d_col_spec, use_ddp) + run_gpt(init_megatron_spec, use_ddp) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize('use_ddp', [False, True]) +@rerun_if_address_is_in_use() +def test_gpt(world_size, use_ddp): + run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_gpt(4, use_ddp=False) diff --git a/tests/test_tensor/model/test_model.py b/tests/test_tensor/model/test_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3f53b94e0642c4edca34101ff1e89b824f0d84bd --- /dev/null +++ b/tests/test_tensor/model/test_model.py @@ -0,0 +1,340 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.tensor import ColoTensor, ProcessGroup +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import ( + check_equal, + set_seed, + split_param_col_tp1d, + split_param_row_tp1d, + tensor_shard_equal, +) + + +def run_1d_hybrid_tp(model_name): + # A simple net with two stacked nn.Linear + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + set_seed(1) + with ColoInitContext(device=get_current_device()): + model = model_builder(checkpoint=True) + + if rank == 0: + model_torch = model_builder(checkpoint=True) + model_torch = model_torch.cuda() + + optimizer_torch = ColossalaiOptimizer(torch.optim.SGD(model_torch.parameters(), lr=0.1)) + + # Make two models have the same init params + for p1, p2 in zip(model.parameters(), model_torch.parameters()): + p2.data.copy_(p1.data) + else: + model_torch = None + optimizer_torch = None + + pg = ProcessGroup(tp_degree=world_size) + if 'bert' == model_name: + for name, p in model.named_parameters(): + if not isinstance(p, ColoTensor): + continue + + # num_class = type_vocab_size = 2 | (8, 2) + if 'classifier' in name and 'weight' in name: + split_param_col_tp1d(p, pg) + # num_class = vocab_size = 30524 | (30524, 8) + elif 'word_embeddings' in name and 'weight' in name: + split_param_row_tp1d(p, pg) + # num_class = seq_len = 512 | (512, 8) + elif 'position_embeddings' in name and 'weight' in name: + split_param_row_tp1d(p, pg) + # num_class = type_vocab_size = 2 | (2, 8) + elif 'token_type_embeddings' in name and 'weight' in name: + split_param_col_tp1d(p, pg) + + elif "simple_net" == model_name: + # A naive way to set spec for all weights in Linear + for name, p in model.named_parameters(): + if not isinstance(p, ColoTensor): + continue + if 'embed' in name and 'weight' in name: + split_param_col_tp1d(p, pg) + if 'proj1' in name and ('weight' in name or 'bias' in name): + split_param_row_tp1d(p, pg) + if 'proj2' in name and 'weight' in name: + split_param_col_tp1d(p, pg) + if 'classifier' in name and ('weight' in name or 'bias' in name): + split_param_row_tp1d(p, pg) + + model = model.cuda() + model.eval() + if rank == 0: + model_torch.eval() + + colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1)) + + for i, (data, label) in enumerate(train_dataloader): + + # Zero grad + colo_optimizer.zero_grad() + if rank == 0: + optimizer_torch.zero_grad() + torch.distributed.barrier() + + data = data.to(get_current_device()) + label = label.to(get_current_device()) + + torch.distributed.broadcast(data, 0, group=pg.tp_process_group()) + torch.distributed.broadcast(label, 0, group=pg.tp_process_group()) + + # Bcast rank0 data to all processes + if criterion: + output = model(data) + loss = criterion(output, label) + else: + output = model(data, label) + loss = output + + # Test output + if rank == 0: + if criterion: + output_torch = model_torch(data) + loss_torch = criterion(output_torch, label) + else: + output_torch = model_torch(data, label) + loss_torch = output_torch + assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed" + torch.distributed.barrier() + + loss.backward() + colo_optimizer.step() + + if rank == 0: + loss_torch.backward() + optimizer_torch.step() + + with torch.no_grad(): + # check param + for p, torch_p in zip(model.parameters(), model_torch.parameters()): + assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size()) + torch.distributed.barrier() + if i > 5: + break + + +# Test the overrided parameters() and named_parameters() member functions +def test_model_parameters(): + colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') + + # build a module with 2 Linear, 4 parameters in total. + class Net(torch.nn.Module): + + def __init__(self): + super().__init__() + self.fcs = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 2)) + self.extra_param = torch.nn.Parameter(torch.randn(2)) + + with ColoInitContext(device=get_current_device()): + model = Net() + + param_cnt = 0 + for name, p in model.named_parameters(): + param_cnt += 1 + assert param_cnt == 5 + + for name, colo_p in model.named_parameters(): + assert colo_p.is_model_data() + + param_cnt = 0 + for name, p in model.named_parameters(recurse=False): + param_cnt += 1 + assert param_cnt == 1 + + param_cnt = 0 + for p in model.fcs[0].parameters(recurse=False): + param_cnt += 1 + assert param_cnt == 2 + + +def test_colo_optimizer(): + get_components_func = non_distributed_component_funcs.get_callable('simple_net') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + set_seed(1) + with ColoInitContext(device=get_current_device()): + model = model_builder(checkpoint=True) + + colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1)) + for i, (data, label) in enumerate(train_dataloader): + colo_optimizer.zero_grad() + data = data.to(get_current_device()) + label = label.to(get_current_device()) + + # Bcast rank0 data to all processes + if criterion: + output = model(data) + loss = criterion(output, label) + else: + output = model(data, label) + loss = output + + loss.backward() + colo_optimizer.step() + + if i > 5: + break + + +def run_1d_row_tp(model_name: str): + # A simple net with two stacked nn.Linear + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + rank = torch.distributed.get_rank() + + set_seed(1) + with ColoInitContext(device=get_current_device()): + model = model_builder(checkpoint=True) + + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + + set_seed(1) + if rank == 0: + model_torch = model_builder(checkpoint=True) + model_torch = model_torch.cuda() + + # A naive way to set spec for all weights in Linear + for mo_name, module in model.named_modules(): + # print(mo_name) + for pa_name, param in module.named_parameters(recurse=False): + # print('\t', pa_name, param.shape) + if not isinstance(param, ColoTensor): + continue + if 'weight' in pa_name: + if 'embed' in mo_name and 'token' not in mo_name and 'LayerNorm' not in mo_name: + split_param_row_tp1d(param, pg) + elif 'LayerNorm' not in mo_name and 'ln' not in mo_name: + split_param_col_tp1d(param, pg) + + model = model.cuda() + + for i, (data, label) in enumerate(train_dataloader): + data = data.to(get_current_device()) + label = label.to(get_current_device()) + + torch.distributed.broadcast(data, 0, group=pg.tp_process_group()) + torch.distributed.broadcast(label, 0, group=pg.tp_process_group()) + + # Bcast rank0 data to all processes + if criterion: + output = model(data) + loss = criterion(output, label) + else: + output = model(data, label) + loss = output + + # For reference + if rank == 0: + if criterion: + output_torch = model_torch(data) + loss_torch = criterion(output_torch, label) + else: + output_torch = model_torch(data, label) + loss_torch = output_torch + assert torch.allclose(loss, loss_torch, rtol=1e-2) + torch.distributed.barrier() + + loss.backward() + + if rank == 0: + loss_torch.backward() + torch.distributed.barrier() + + if i > 5: + break + + +def _run_pretrain_load(): + from transformers import BertForMaskedLM + set_seed(1) + model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased') + with ColoInitContext(device=get_current_device()): + model = BertForMaskedLM.from_pretrained('bert-base-uncased') + + model_pretrained = model_pretrained.cuda() + model = model.cuda() + + dict_pretrained = {} + dict_col = {} + c_ref = 0 + for name, param in model_pretrained.named_parameters(): + dict_pretrained[name] = param + c_ref += 1 + c1 = 0 + c2 = 0 + for name, param in model.named_parameters(): + if isinstance(param, ColoParameter): + c1 += 1 + else: + c2 += 1 + dict_col[name] = param + assert c_ref == c1 + assert c2 == 0 + if model_pretrained.cls.predictions.decoder.bias is model_pretrained.cls.predictions.bias: + assert model.cls.predictions.decoder.bias is model.cls.predictions.bias + + for name, param in dict_pretrained.items(): + check_equal(param, dict_col[name]) + + +def run_model_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + # Comment below test for speed consideration + # for name in ['bert', 'simple_net']: + # run_1d_row_tp(name) + for name in ['bert', 'simple_net']: + run_1d_hybrid_tp(name) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_model(world_size): + run_func = partial(run_model_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +def run_pretrain_load_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + _run_pretrain_load() + + +# The test case has to download huggingface pretrained models from the internet +# So we manually trigger the test. +@pytest.mark.skip +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_pretrain_load(world_size): + run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + # test_model_parameters() + # test_colo_optgimizer() + test_model(4) + # test_pretrain_load(4) diff --git a/tests/test_tensor/model/test_module_spec.py b/tests/test_tensor/model/test_module_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..997b416f12c3ce3219e4156d479b063d45cf1514 --- /dev/null +++ b/tests/test_tensor/model/test_module_spec.py @@ -0,0 +1,233 @@ +from copy import deepcopy +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.nn.parallel.layers import check_colo_module, init_colo_module +from colossalai.tensor import ( + ColoTensor, + ColoTensorSpec, + ComputePattern, + ComputeSpec, + ProcessGroup, + ReplicaSpec, + ShardSpec, + distspec, +) +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal + + +def run_model_with_spec(mode, model_name): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + rank = pg.rank() + + set_seed(1) + with ColoInitContext(device=get_current_device()): + model = model_builder(checkpoint=False) + + if rank == 0: + model_seq = model_builder(checkpoint=False) + model_seq = model_seq.cuda() + + # Make two models have the same init params + for p1, p2 in zip(model.parameters(), model_seq.parameters()): + p2.data.copy_(p1.data) + + compute_spec = ComputeSpec(ComputePattern.TP1D) + # Not all layers in Bert can be mod by 4. + # e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2. + if 'bert' == model_name: + if 'col' == mode: + init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode=mode) + init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode) + init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode='row') + elif 'row' == mode: + init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode='col') + init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode) + init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode=mode) + elif 'simple_net' == model_name: + init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode) + + model = model.cuda() + for i, (data, label) in enumerate(train_dataloader): + data = data.to(get_current_device()) + label = label.to(get_current_device()) + + torch.distributed.broadcast(data, 0, group=pg.tp_process_group()) + torch.distributed.broadcast(label, 0, group=pg.tp_process_group()) + + if criterion: + output = model(data) + loss = criterion(output, label) + else: + output = model(data, label) + loss = output + + # For reference + if rank == 0: + if criterion: + output_seq = model_seq(data) + loss_seq = criterion(output_seq, label) + else: + output_seq = model_seq(data, label) + loss_seq = output_seq + + if rank == 0: + with torch.no_grad(): + assert torch.allclose(loss, loss_seq, rtol=1e-2) + + loss.backward() + + if rank == 0: + loss_seq.backward() + + with torch.no_grad(): + # check param + for p1, p2 in zip(model.parameters(), model_seq.parameters()): + if p1.size() == p2.size(): + assert torch.allclose(p1, p2) + else: + if p1.size(-1) < p2.size(-1): # col + world_size = p2.size(-1) // p1.size(-1) + split_p2 = torch.chunk(p2, world_size, dim=-1)[0] + + elif p1.size(0) < p2.size(0): # row + world_size = p2.size(0) // p1.size(0) + split_p2 = torch.chunk(p2, world_size, dim=0)[0] + + assert torch.allclose(p1, split_p2) + + if i > 3: + break + + +def run_linear_with_spec(mode): + with ColoInitContext(device=get_current_device()): + model = torch.nn.Linear(4, 8) + + model_handy = deepcopy(model) + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + compute_spec = ComputeSpec(ComputePattern.TP1D) + init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode) + + x = torch.rand(2, 4).cuda() + colo_x = ColoTensor.from_torch_tensor(x, ColoTensorSpec(pg)) + + out = model(x) + colo_out = model_handy(colo_x) + assert tensor_equal(out, colo_out) + + grad = torch.rand_like(out) + out.backward(grad) + colo_out.backward(grad) + + assert tensor_shard_equal(model_handy.weight.grad, model.weight.grad, pg.tp_local_rank(), pg.tp_world_size()) + assert tensor_shard_equal(model_handy.bias.grad, model.bias.grad, pg.tp_local_rank(), pg.tp_world_size()) + + +def run_check_shared_param(): + from transformers import BertConfig, BertForMaskedLM + hidden_dim = 8 + num_head = 4 + sequence_length = 12 + num_layer = 2 + vocab_size = 24 + + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + rank = pg.rank() + + config = BertConfig(vocab_size=vocab_size, + hidden_size=hidden_dim, + intermediate_size=hidden_dim * 4, + num_attention_heads=num_head, + max_position_embeddings=sequence_length, + num_hidden_layers=num_layer, + hidden_dropout_prob=0., + attention_probs_dropout_prob=0.) + with ColoInitContext(device=get_current_device()): + model = BertForMaskedLM(config) + + model = model.cuda() + compute_spec = ComputeSpec(ComputePattern.TP1D) + # model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec + assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2 + # They are all Linear, so both row is allowed. This should pass check. + init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row') + # This should be detected by check because you can not set weight as row while set bias as col. + col_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + + # TODO(jiaruifang) optimize this line + if not model.cls.predictions.bias.has_initialized: + model.cls.predictions.bias.pg = pg + model.cls.predictions.bias.dist_spec = ReplicaSpec() + model.cls.predictions.bias.has_initialized = True + model.cls.predictions.bias.set_tensor_spec(*col_spec) + try: + check_colo_module(model.cls.predictions.decoder, pg=pg, recursive=False) + except Exception as e: + assert 'incorrectly sharded' in str(e) + + +def run_dist(rank, world_size, port): + config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_linear_with_spec('col') + run_linear_with_spec('row') + + +def run_dist_model(rank, world_size, port): + config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + for model_name in ['simple_net', 'bert']: + run_model_with_spec('col', model_name) + run_model_with_spec('row', model_name) + + +def run_dist_check(rank, world_size, port): + config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_check_shared_param() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.skip("for higher testing speed") +@rerun_if_address_is_in_use() +def test_module_linear_1d(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.skip("for higher testing speed") +@rerun_if_address_is_in_use() +def test_module_model(world_size): + run_func = partial(run_dist_model, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@pytest.mark.skip("for higher testing speed") +@rerun_if_address_is_in_use() +def test_module_check(world_size): + run_func = partial(run_dist_check, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_module_linear_1d(4) diff --git a/tests/test_tensor/test_colo_checkpoint_tools.py b/tests/test_tensor/test_colo_checkpoint_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..5f440ae79b838d1d4ca1686411c06734fa42e68d --- /dev/null +++ b/tests/test_tensor/test_colo_checkpoint_tools.py @@ -0,0 +1,47 @@ +import torch +import pytest +from functools import partial + +import torch.multiprocessing as mp +import torch.distributed as dist + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils.cuda import get_current_device +from colossalai.utils import free_port +from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, ColoTensorSpec +from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor +from tests.test_tensor.common_utils import tensor_shard_equal + + +def run_dist(rank, world_size, port, dp_degree, tp_degree): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree) + x = torch.randn(4, 4) + param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg)) + spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D) + param.set_tensor_spec(*spec) + + gather_tensor(param) + if dist.get_rank() == 0: + assert torch.all(x == param) + else: + assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) + dist.barrier() + + scatter_tensor(param, spec[0]) + assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) + assert param.requires_grad is True + dist.barrier() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4]) +@rerun_if_address_is_in_use() +def test_checkpoint(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port(), dp_degree=2, tp_degree=world_size // 2) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_checkpoint(world_size=4) diff --git a/tests/test_tensor/test_comm_spec_apply.py b/tests/test_tensor/test_comm_spec_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..46eee61f1ecf2e4979ab75f8c43ebbfeea820fe3 --- /dev/null +++ b/tests/test_tensor/test_comm_spec_apply.py @@ -0,0 +1,226 @@ +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.distributed import ReduceOp + +from colossalai.core import global_context as gpc +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port + + +def check_all_gather(device_mesh, rank): + # tensor to comm + if rank in (0, 2): + sharded_tensor_to_comm = torch.ones(2, 2).cuda() + else: + sharded_tensor_to_comm = torch.zeros(2, 2).cuda() + + # tensor to check + tensor_to_check = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), 1).cuda() + + # test all gather + dim_partition_dict = {1: [1]} + + # DistSpec: + # shard_sequence: R,S1 + # device_mesh_shape: (2, 2) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=1) + sharded_tensor_to_comm = sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm) + + assert sharded_tensor_to_comm.equal(tensor_to_check) + + +def check_shard(device_mesh, rank): + # tensor to comm + sharded_tensor_to_comm_0 = torch.zeros(2, 2).cuda() + sharded_tensor_to_comm_1 = torch.ones(2, 2).cuda() + # tensor([[0., 0., 1., 1.], + # [0., 0., 1., 1.]]) + tensor_to_shard = torch.cat((sharded_tensor_to_comm_0, sharded_tensor_to_comm_1), 1) + + # test shard + dim_partition_dict = {} + + # DistSpec: + # shard_sequence: R,R + # device_mesh_shape: (2, 2) + sharding_spec = ShardingSpec(device_mesh, tensor_to_shard.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, sharding_spec, shard_dim=1, logical_process_axis=1) + tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard) + + if rank in (0, 2): + assert tensor_to_shard.equal(sharded_tensor_to_comm_0) + if rank in (1, 3): + assert tensor_to_shard.equal(sharded_tensor_to_comm_1) + + +def check_all_to_all(device_mesh, rank): + # tensor to comm + if rank in (0, 1): + sharded_tensor_0 = torch.zeros(2, 1) + sharded_tensor_1 = torch.ones(2, 1) + # tensor([[0., 1.], + # [0., 1.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + if rank in (2, 3): + sharded_tensor_0 = torch.ones(2, 1) * 2 + sharded_tensor_1 = torch.ones(2, 1) * 3 + # tensor([[2., 3.], + # [2., 3.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + + if rank in (0, 1): + # tensor([[0.], + # [0.], + # [2.], + # [2.]]) + tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda() + if rank in (2, 3): + # tensor([[1.], + # [1.], + # [3.], + # [3.]]) + tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda() + + # test shard + dim_partition_dict = {0: [0]} + + # DistSpec: + # shard_sequence: S0,R + # device_mesh_shape: (2, 2) + sharding_spec = ShardingSpec(device_mesh, torch.Size((4, 2)), dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, + sharding_spec, + gather_dim=0, + shard_dim=1, + logical_process_axis=0) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_all_reduce_fwd(device_mesh, rank): + # tensor to comm + tensor_to_comm = torch.ones(2, 2).cuda() * rank + + # reduce through logical process axis 0 + # tensor to check + if rank in (0, 2): + # tensor([[2., 2.], + # [2., 2.]]) + tensor_to_check = torch.tensor([[2, 2], [2, 2]], dtype=tensor_to_comm.dtype).cuda() + if rank in (1, 3): + # tensor([[4., 4.], + # [4., 4.]]) + tensor_to_check = torch.tensor([[4, 4], [4, 4]], dtype=tensor_to_comm.dtype).cuda() + + dim_partition_dict = {} + # DistSpec: + # shard_sequence: R,R + # device_mesh_shape: (2, 2) + sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_all_reduce_bwd(device_mesh, rank): + # tensor to comm + tensor_to_comm = torch.ones(2, 2).cuda() * rank + + tensor_to_check = torch.ones(2, 2).cuda() * rank + + dim_partition_dict = {} + # DistSpec: + # shard_sequence: R,R + # device_mesh_shape: (2, 2) + sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, sharding_spec, logical_process_axis=0) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_all_reduce_in_flatten_device_mesh(device_mesh, rank): + # tensor to comm + tensor_to_comm = torch.ones(2, 2).cuda() * rank + + # reduce through logical process axis 0 at flatten device mesh + # tensor to check + # tensor([[6., 6.], + # [6., 6.]]) + tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda() + + dim_partition_dict = {} + # DistSpec: + # shard_sequence: R,R + # device_mesh_shape: (2, 2) + sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1]) + comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1]) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_comm(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + physical_mesh_id = torch.arange(0, 4) + assert rank == gpc.get_global_rank() + + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + # test all gather + check_all_gather(device_mesh, rank) + + # test shard + check_shard(device_mesh, rank) + + # test all to all + check_all_to_all(device_mesh, rank) + + # test all reduce + check_all_reduce_fwd(device_mesh, rank) + check_all_reduce_bwd(device_mesh, rank) + + # test all reduce in 1D flatten device mesh + check_all_reduce_in_flatten_device_mesh(device_mesh, rank) + gpc.destroy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_comm_spec(): + world_size = 4 + run_func = partial(check_comm, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_comm_spec() diff --git a/tests/test_tensor/test_context.py b/tests/test_tensor/test_context.py new file mode 100644 index 0000000000000000000000000000000000000000..2f7aebed5bc434384235ebf94270b9702284d89c --- /dev/null +++ b/tests/test_tensor/test_context.py @@ -0,0 +1,69 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.tensor import ( + ColoParameter, + ColoTensorSpec, + ComputePattern, + ComputeSpec, + ProcessGroup, + ReplicaSpec, + ShardSpec, +) +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed + + +def run_colo_init_context(rank: int, world_size: int, port: int): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # make sure seed of each process is the same, so the params are consistent among processes and the params are exactly replicated. + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + # keep parameters replicated during init + with ColoInitContext(device=get_current_device()): + model1 = model_builder() + + # shard the parameters during init + set_seed(42) + shard_spec = ReplicaSpec() + + # If using ShardSpec, the assertations will failed. + # But it is not a bug, the initialized values are not consist with the original one. + # shard_spec = ShardSpec(dims=[0], num_partitions=[world_size]) + default_pg = ProcessGroup(tp_degree=world_size) + with ColoInitContext(device=get_current_device(), default_pg=default_pg, default_dist_spec=shard_spec): + model2 = model_builder() + + # reshard both models + new_shard = ShardSpec(dims=[-1], num_partitions=[world_size]) + for p1, p2 in zip(model1.parameters(), model2.parameters()): + p1: ColoParameter = p1 + p1.set_process_group(ProcessGroup(tp_degree=world_size)) + p1.set_dist_spec(new_shard) + p2.set_dist_spec(new_shard) + + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert (torch.allclose(p1, p2)) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_colo_init_context(world_size): + run_func = partial(run_colo_init_context, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_colo_init_context(2) diff --git a/tests/test_tensor/test_mix_gather.py b/tests/test_tensor/test_mix_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..c1ab30601501b0ea60507bc9af5b62eb3b8ae3fb --- /dev/null +++ b/tests/test_tensor/test_mix_gather.py @@ -0,0 +1,333 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +from colossalai.core import global_context as gpc +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.tensor.utils import mix_gather_simulator +from colossalai.utils import free_port + + +def check_mix_gather_S0S1(device_mesh, rank): + tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + (f, b) = (0, 1) + f_target_pair = (f, [0]) + b_target_pair = (b, [1]) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_slice = [4, 2] # (4, 2) + rank_slice = 4 + f_start = (rank // rank_slice) * tensor_slice[0] + b_start = (rank % rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + dim_partition_dict = {0: [0], 1: [1]} + + # DistSpec: + # shard_sequence: S0,S1 + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_S0S1(device_mesh, rank): + tensor_width = 8 + tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + dim_partition_dict = {0: [0], 1: [1]} + + tensor_slice = [tensor_width // 2, tensor_width // 4] # (4, 2) + rank_slice = 4 + f_start = (rank // rank_slice) * tensor_slice[0] + b_start = (rank % rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + # DistSpec: + # shard_sequence: S0,S1 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {1: [1]} + # DistSpec: + # shard_sequence: R,S1 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_mix_gather_S1S0(device_mesh, rank): + tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + (f, b) = (0, 1) + f_target_pair = (f, [1]) + b_target_pair = (b, [0]) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_slice = [2, 4] + rank_slice = 4 + f_start = (rank % rank_slice) * tensor_slice[0] + b_start = (rank // rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + dim_partition_dict = {0: [1], 1: [0]} + + # DistSpec: + # shard_sequence: S1,S0 + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_S1S0(device_mesh, rank): + tensor_width = 8 + tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + tensor_slice = [tensor_width // 4, tensor_width // 2] # (4, 2) + rank_slice = 4 + f_start = (rank % rank_slice) * tensor_slice[0] + b_start = (rank // rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + dim_partition_dict = {0: [1], 1: [0]} + + # DistSpec: + # shard_sequence: S1,S0 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {1: [0]} + # DistSpec: + # shard_sequence: R,S0 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_mix_gather_S01R(device_mesh, rank): + tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + (f, b) = (0, 1) + f_target_pair = (f, [0, 1]) + b_target_pair = (b, []) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_to_comm = tensor_to_check[rank:rank + 1, :].contiguous().cuda() + + dim_partition_dict = {0: [0, 1]} + # DistSpec: + # shard_sequence: S01,R + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_S01R(device_mesh, rank): + tensor_width = 8 + tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + rank_stride = tensor_width // 8 + tensor_to_comm = tensor_to_check[rank:rank + rank_stride, :].contiguous().cuda() + + dim_partition_dict = {0: [0, 1]} + + # DistSpec: + # shard_sequence: S01, R + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {0: [0]} + + # DistSpec: + # shard_sequence: S1, R + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_mix_gather_RS01(device_mesh, rank): + tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + + (f, b) = (0, 1) + f_target_pair = (f, []) + b_target_pair = (b, [0, 1]) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_to_comm = tensor_to_check[:, rank:rank + 1].contiguous().cuda() + + dim_partition_dict = {1: [0, 1]} + # DistSpec: + # shard_sequence: R, S01 + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_RS01(device_mesh, rank): + tensor_width = 8 + tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + rank_stride = tensor_width // 8 + tensor_to_comm = tensor_to_check[:, rank:rank + rank_stride].contiguous().cuda() + + dim_partition_dict = {1: [0, 1]} + + # DistSpec: + # shard_sequence: R, S01 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {1: [0]} + + # DistSpec: + # shard_sequence: R, S1 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_comm(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + physical_mesh_id = torch.arange(0, 8) + assert rank == gpc.get_global_rank() + + mesh_shape = (2, 4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True, need_flatten=True) + + check_mix_gather_S0S1(device_mesh, rank) + + check_two_all_gather_S0S1(device_mesh, rank) + + check_mix_gather_S1S0(device_mesh, rank) + + check_two_all_gather_S1S0(device_mesh, rank) + + check_mix_gather_S01R(device_mesh, rank) + + check_two_all_gather_S01R(device_mesh, rank) + + check_mix_gather_RS01(device_mesh, rank) + + check_two_all_gather_RS01(device_mesh, rank) + + +@pytest.mark.skip(reason="Skip because the check functions assume 8 GPUS but CI only have 4 GPUs") +def test_mix_gather(): + world_size = 8 + run_func = partial(check_comm, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_mix_gather() diff --git a/tests/test_tensor/test_parameter.py b/tests/test_tensor/test_parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..7c3c4b2132e4ab21c03900175e5c1cf09deebf1b --- /dev/null +++ b/tests/test_tensor/test_parameter.py @@ -0,0 +1,33 @@ +from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup +import torch +import pytest +from common_utils import tensor_equal +import colossalai +from colossalai.utils import free_port + + +@pytest.mark.skip +def test_multiinheritance(): + colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') + colo_param = ColoParameter(None, requires_grad=True) + assert colo_param.dist_spec.placement.value == 'r' + assert isinstance(colo_param, ColoTensor) + assert isinstance(colo_param, torch.nn.Parameter) + + # __deepcopy__ overload + import copy + colo_param2 = copy.deepcopy(colo_param) + assert isinstance(colo_param2, ColoParameter) + assert tensor_equal(colo_param.data, colo_param2.data) + assert colo_param.requires_grad == colo_param2.requires_grad + + # __repr__ overload + assert 'ColoParameter' in str(colo_param) + + # __torch_function__ + clone_param = torch.clone(colo_param) + assert isinstance(clone_param, ColoTensor) + + +if __name__ == '__main__': + test_multiinheritance() diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py new file mode 100644 index 0000000000000000000000000000000000000000..6fe9ee292cd014ae37d41ec5df4f7fcd6ccb2d99 --- /dev/null +++ b/tests/test_tensor/test_shape_consistency.py @@ -0,0 +1,144 @@ +from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern +import torch +from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec +from colossalai.device.device_mesh import DeviceMesh + +physical_mesh_id = torch.arange(0, 16).reshape(2, 8) +mesh_shape = (4, 4) +# [[0, 1, 2, 3], +# [4, 5, 6, 7], +# [8, 9, 10,11], +# [12,13,14,15]] +device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) +entire_shape = torch.Size((64, 32, 16)) +shape_consistency_manager = ShapeConsistencyManager() + + +def test_one_step_transform(): + + dim_partition_dict = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4) + sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) + + # {DistSpec: + # shard_sequence: R,S1,R + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0), DistSpec: + # shard_sequence: S0,R,R + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)} + rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, { + "forward": 0, + "backward": 0, + "total": 0 + }) + + assert '[R, S1, R]' in [ + str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys() + ] + assert '[S0, R, R]' in [ + str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys() + ] + + dim_partition_dict_all2all = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4) + sharding_spec_all2all = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_all2all) + # {DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 1), 0), DistSpec: + # shard_sequence: R,S1,S0 + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0), DistSpec: + # shard_sequence: S0,R,S1 + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)} + rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, { + "forward": 0, + "backward": 0, + "total": 0 + }) + + assert '[S01, R, R]' in [ + str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys() + ] + assert '[R, S1, S0]' in [ + str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys() + ] + assert '[S0, R, S1]' in [ + str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys() + ] + + dim_partition_shard = {0: [0]} + # DistSpec: + # shard_sequence: S0,R,R + # device_mesh_shape: (4, 4) + sharding_spec_shard = ShardingSpec(device_mesh, entire_shape, dim_partition_shard) + # {DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1), 0), DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1), 0), DistSpec: + # shard_sequence: S0,R,S1 + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)} + rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, { + "forward": 0, + "backward": 0, + "total": 0 + }) + + assert '[S01, R, R]' in [ + str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys() + ] + assert '[S0, S1, R]' in [ + str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys() + ] + assert '[S0, R, S1]' in [ + str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys() + ] + + +def test_shape_consistency(): + dim_partition_source = {1: [0, 1]} + dim_partition_target = {0: [0, 1]} + + # DistSpec: + # shard_sequence: R,S01,R + # device_mesh_shape: (4, 4) + sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source) + + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) + + transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( + sharding_spec_source, sharding_spec_target) + + transform_path_str = '->'.join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path]) + assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]' + + # all-gather(S01) -> S0 + assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD + assert comm_action_sequence[0].gather_dim == 1 + assert comm_action_sequence[0].logical_process_axis == 1 + + # all-to-all(R, S0) -> [S0, R] + assert comm_action_sequence[1].comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD + assert comm_action_sequence[1].gather_dim == 1 + assert comm_action_sequence[1].shard_dim == 0 + assert comm_action_sequence[1].logical_process_axis == 0 + + # shard(S0) -> [S01] + assert comm_action_sequence[2].comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD + assert comm_action_sequence[2].shard_dim == 0 + assert comm_action_sequence[2].logical_process_axis == 1 + + assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]', + '[S01, R, R]')][0] == transform_path + assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]', + '[S01, R, R]')][1] == comm_action_sequence + + +if __name__ == '__main__': + test_one_step_transform() + test_shape_consistency() diff --git a/tests/test_tensor/test_shape_consistency_apply.py b/tests/test_tensor/test_shape_consistency_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..4c838bc83fad1d4fbcb03de2b553efac28f6c048 --- /dev/null +++ b/tests/test_tensor/test_shape_consistency_apply.py @@ -0,0 +1,81 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port + + +def check_apply(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + entire_shape = torch.Size((4, 2)) + shape_consistency_manager = ShapeConsistencyManager() + dim_partition_source = {0: [0]} + dim_partition_target = {1: [0]} + + # DistSpec: + # shard_sequence: S0,R + # device_mesh_shape: (2, 2) + sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source) + + # DistSpec: + # shard_sequence: R,S0 + # device_mesh_shape: (2, 2) + sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) + + if rank in (0, 1): + sharded_tensor_0 = torch.zeros(2, 1) + sharded_tensor_1 = torch.ones(2, 1) + # tensor([[0., 1.], + # [0., 1.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + if rank in (2, 3): + sharded_tensor_0 = torch.ones(2, 1) * 2 + sharded_tensor_1 = torch.ones(2, 1) * 3 + # tensor([[2., 3.], + # [2., 3.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + + if rank in (0, 1): + # tensor([[0.], + # [0.], + # [2.], + # [2.]]) + tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda() + if rank in (2, 3): + # tensor([[1.], + # [1.], + # [3.], + # [3.]]) + tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda() + + tensor_to_comm.sharding_spec = sharding_spec_source + tensor_to_comm = shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target) + assert tensor_to_comm.equal(tensor_to_check) + assert str(tensor_to_comm.sharding_spec.sharding_sequence) == str(sharding_spec_target.sharding_sequence) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_apply(): + world_size = 4 + run_func = partial(check_apply, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_apply() diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..85008c67a9c269f64137049f3907cd586a5a32df --- /dev/null +++ b/tests/test_tensor/test_sharded_linear.py @@ -0,0 +1,237 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F + +import colossalai +from colossalai.device.device_mesh import DeviceMesh +from colossalai.nn._ops._utils import gather_forward_split_backward +from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # create mlp vars + x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda() + w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda() + b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda() + + # run normal forward + out = F.linear(x, w, b) + + # create mesh meta + # the mesh is in the following topo + # [[0, 1], + # [2, 3]] + physical_mesh_id = torch.arange(0, 4).reshape(2, 2) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + row_id = rank // 2 + column_id = rank % 2 + + # create pg + row_process_group = None + col_process_group = None + row_to_ranks = {0: [0, 1], 1: [2, 3]} + col_to_ranks = {0: [0, 2], 1: [1, 3]} + + for idx in range(2): + # row ranks + row_ranks = row_to_ranks[idx] + row_pg = ProcessGroup(ranks=row_ranks, tp_degree=2) + + # col ranks + col_ranks = col_to_ranks[idx] + col_pg = ProcessGroup(ranks=col_ranks, tp_degree=2) + + if rank in row_ranks: + row_process_group = row_pg + + if rank in col_ranks: + col_process_group = col_pg + + ######################## + # RRR x RS0 -> RRS0 # + ######################## + # w will be transposed in F.linear + x_replica = x.detach().clone() + w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[row_id] + b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[row_id] + + # adding sharding spec + x_replica.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={}) + w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [0]}) + b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [0]}) + + # check sharding spec + assert str(x_replica.sharding_spec.sharding_sequence) == "[R, R, R]" + assert str(w_shard.sharding_spec.sharding_sequence) == "[S0, R]" + assert str(b_shard.sharding_spec.sharding_sequence) == "[S0]" + + w_shard.pg_axis0 = col_process_group + w_shard.pg_axis1 = row_process_group + + out_shard = F.linear(x_replica, w_shard, b_shard) + assert str(out_shard.sharding_spec.sharding_sequence) == "[R, R, S0]" + + # each row only has a mini-batch + expected_out_shard = torch.chunk(out, chunks=2, dim=2)[row_id] + assert torch.allclose(out_shard, expected_out_shard) + + ######################## + # S0RR x RS1 -> S0RS1 # + ######################## + # w will be transposed in F.linear + x_shard = torch.chunk(x.detach().clone(), chunks=2, dim=0)[row_id] + w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[column_id] + b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[column_id] + + # adding sharding spec + x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0]}) + w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1]}) + b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]}) + + # check sharding spec + assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, R]" + assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, R]" + assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]" + + w_shard.pg_axis0 = col_process_group + w_shard.pg_axis1 = row_process_group + + out_shard = F.linear(x_shard, w_shard, b_shard) + + # each row only has a mini-batch + expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id] + expected_out_shard = torch.chunk(expected_out_shard, chunks=2, dim=2)[column_id] + assert torch.allclose(out_shard, expected_out_shard) + + ######################## + # S0RS1 x S1R -> S0RR # + ######################## + # w will be transposed in F.linear + x_shard = torch.chunk(x.clone(), chunks=2, dim=0)[row_id] + x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id] + w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id] + b_replica = b.clone() + + # adding sharding spec + x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0], 2: [1]}) + w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]}) + b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={}) + + # check sharding spec + assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, S1]" + assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]" + assert str(b_replica.sharding_spec.sharding_sequence) == "[R]" + + w_shard.pg_axis0 = col_process_group + w_shard.pg_axis1 = row_process_group + + out_shard = F.linear(x_shard, w_shard, b_replica) + + # each row only has a mini-batch + expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id] + assert torch.allclose(out_shard, expected_out_shard) + + ######################## + # RRS0 x S0R -> RRR # + ######################## + # w will be transposed in F.linear + x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id] + w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id] + b_replica = b.clone() + + # adding sharding spec + x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]}) + w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [0]}) + b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={}) + + # check sharding spec + assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]" + assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S0]" + assert str(b_replica.sharding_spec.sharding_sequence) == "[R]" + + w_shard.pg_axis0 = col_process_group + w_shard.pg_axis1 = row_process_group + + out_shard = F.linear(x_shard, w_shard, b_replica) + + # each row only has a mini-batch + expected_out_shard = out + assert torch.allclose(out_shard, expected_out_shard) + + ######################## + # RS0S1 x S1R -> RS0R # + ######################## + # w will be transposed in F.linear + x_shard = torch.chunk(x.clone(), chunks=2, dim=1)[row_id] + x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id] + w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id] + b_replica = b.clone() + + # adding sharding spec + x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={1: [0], 2: [1]}) + w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]}) + b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={}) + + # check sharding spec + assert str(x_shard.sharding_spec.sharding_sequence) == "[R, S0, S1]" + assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]" + assert str(b_replica.sharding_spec.sharding_sequence) == "[R]" + + w_shard.pg_axis0 = col_process_group + w_shard.pg_axis1 = row_process_group + + out_shard = F.linear(x_shard, w_shard, b_replica) + + # each row only has a mini-batch + expected_out_shard = torch.chunk(out, chunks=2, dim=1)[row_id] + assert torch.allclose(out_shard, expected_out_shard) + + ######################## + # RRS0 x S0S1 -> RRS1 # + ######################## + # w will be transposed in F.linear + x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id] + w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id] + w_shard = torch.chunk(w_shard, chunks=2, dim=0)[column_id] + b_shard = torch.chunk(b.clone(), chunks=2, dim=0)[column_id] + + # adding sharding spec + x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]}) + w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1], 1: [0]}) + b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]}) + + # check sharding spec + assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]" + assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, S0]" + assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]" + + w_shard.pg_axis0 = col_process_group + w_shard.pg_axis1 = row_process_group + + out_shard = F.linear(x_shard, w_shard, b_shard) + + # each row only has a mini-batch + expected_out_shard = torch.chunk(out, chunks=2, dim=2)[column_id] + assert torch.allclose(out_shard, expected_out_shard) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4]) +@rerun_if_address_is_in_use() +def test_sharded_mlp(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_sharded_mlp(4) diff --git a/tests/test_tensor/test_sharding_spec.py b/tests/test_tensor/test_sharding_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..909c84ef0f0ebeadcaab09c8fa23dd0a90199843 --- /dev/null +++ b/tests/test_tensor/test_sharding_spec.py @@ -0,0 +1,25 @@ +import torch + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec + + +def test_sharding_spec(): + physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + mesh_shape = (4, 4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7], + # [8, 9, 10,11], + # [12,13,14,15]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + entire_shape = torch.Size((16, 8, 6)) + dim_partition_dict = {0: [0, 1]} + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) + assert str(sharding_spec.sharding_sequence) == "[S01, R, R]" + + +if __name__ == '__main__': + test_sharding_spec() diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py new file mode 100644 index 0000000000000000000000000000000000000000..33db676cb85f7b7b4fdbf913e2c8cd60b11ceaa5 --- /dev/null +++ b/tests/test_tensor/test_tp_with_zero.py @@ -0,0 +1,152 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.gemini.chunk import search_chunk_configuration +from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer +from colossalai.nn.parallel import GeminiDDP, ZeroDDP +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed, tensor_shard_equal +from tests.test_tensor.model.test_gpt2 import init_megatron_spec + + +def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup): + zero_dict = model.state_dict(only_rank_0=False) + torch_dict = torch_model.state_dict() + + for key, value in torch_dict.items(): + # key is 'module.model.PARAMETER', so we truncate it + key = key[7:] + if key == 'model.lm_head.weight': + continue + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) + assert tensor_shard_equal(value, temp_zero_value, pg.tp_local_rank(), pg.tp_world_size()), \ + "parameter '{}' has problem.".format(key) + + +def run_fwd_bwd(model, criterion, optimizer, input_ids): + optimizer.zero_grad() + logits = model(input_ids) + logits = logits.float() + loss = criterion(logits, input_ids) + optimizer.backward(loss) + return logits + + +def init_1d_row_spec(model, pg: ProcessGroup): + spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + for n, p in model.named_parameters(): + p.set_process_group(pg) + if 'weight' in n and 'ln' not in n: + p.set_tensor_spec(*spec) + + +def init_1d_col_spec(model, pg: ProcessGroup): + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + for n, p in model.named_parameters(): + p.set_process_group(pg) + if 'ln' not in n and ('weight' in n or 'bias' in n): + p.set_tensor_spec(*spec) + + +@parameterize('placement_policy', ['cuda', 'cpu']) +def run_gpt(placement_policy, tp_init_spec_func=None): + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + model = model.cuda() + torch_model = model_builder().cuda() + + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + + # world size, dp = 2, tp =2, construct a hybrid parallelism. + if world_size == 4: + pg = ProcessGroup(tp_degree=2) + else: + pg = ProcessGroup(tp_degree=world_size) + + if tp_init_spec_func: + tp_init_spec_func(model, pg) + + dp_world_size = pg.dp_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[dp_world_size]['chunk_size'] = 5000 + config_dict[dp_world_size]['keep_gathered'] = False + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + + model = GeminiDDP(model, init_device, placement_policy, True, False, 32) + # The same as the following 3 lines + # chunk_manager = ChunkManager(config_dict, init_device=init_device) + # gemini_manager = GeminiManager(placement_policy, chunk_manager) + # model = ZeroDDP(model, gemini_manager, pin_memory=True) + + zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1) + # The same as the following 2 lines + # optimizer = HybridAdam(model.parameters(), lr=1e-3) + # zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1) + + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) + + check_param(model, torch_model, pg) + + model.eval() + torch_model.eval() + + set_seed(pg.dp_local_rank()) + for i, (input_ids, label) in enumerate(train_dataloader): + if i > 2: + break + input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg)) + zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo) + torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids) + assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2) + + zero_optim.step() + torch_optim.step() + check_param(model, torch_model, pg) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + if world_size == 4: + run_gpt(tp_init_spec_func=init_megatron_spec) + else: + run_gpt(tp_init_spec_func=init_1d_col_spec) + run_gpt(tp_init_spec_func=init_1d_row_spec) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_gpt(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_gpt(4) diff --git a/tests/test_trainer/test_pipeline/test_p2p.py b/tests/test_trainer/test_pipeline/test_p2p.py new file mode 100644 index 0000000000000000000000000000000000000000..72820c6a1f0d30fb658ba48d44c44a6d1a9ffcae --- /dev/null +++ b/tests/test_trainer/test_pipeline/test_p2p.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from colossalai.communication import (recv_backward, recv_forward, recv_obj_meta, send_backward, + send_backward_recv_forward, send_forward, send_forward_recv_backward, + send_obj_meta) +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import get_dist_logger +from colossalai.utils import free_port, get_current_device +from colossalai.testing import rerun_on_exception + +BATCH_SIZE = 4 +SEQ_LENGTH = 2 +HIDDEN_SIZE = 16 + +CONFIG = dict(parallel=dict(pipeline=dict(size=4), tensor=dict(size=1, mode=None)), seed=1024) + + +def check_equal(A, B): + return torch.allclose(A, B, rtol=1e-5, atol=1e-3) + + +def check_forward(output_tensor, rank, logger): + dist.barrier() + if gpc.is_first_rank(ParallelMode.PIPELINE): + tensor = output_tensor.clone() + else: + tensor = recv_forward(output_tensor.shape) + logger.info('Rank {} received forward. Correct tensor: {}'.format(rank, check_equal(tensor, output_tensor))) + if not gpc.is_last_rank(ParallelMode.PIPELINE): + send_forward(tensor) + logger.info('Rank {} sent forward.'.format(rank)) + + +def check_backward(output_grad, rank, logger): + dist.barrier() + if gpc.is_last_rank(ParallelMode.PIPELINE): + grad = output_grad.clone() + else: + grad = recv_backward(output_grad.shape) + logger.info('Rank {} received backward. Correct grad: {}'.format(rank, check_equal(grad, output_grad))) + if not gpc.is_first_rank(ParallelMode.PIPELINE): + send_backward(grad) + logger.info('Rank {} sent backward.'.format(rank)) + + +def check_forward_backward(output_tensor, output_grad, rank, logger): + dist.barrier() + if not gpc.is_first_rank(ParallelMode.PIPELINE): + tensor = send_backward_recv_forward(output_grad, output_tensor.shape) + logger.info('Rank {} sent backward received forward. Correct tensor: {}'.format( + rank, check_equal(tensor, output_tensor))) + if not gpc.is_last_rank(ParallelMode.PIPELINE): + grad = send_forward_recv_backward(output_tensor, output_grad.shape) + logger.info('Rank {} sent forward received backward. Correct grad: {}'.format( + rank, check_equal(grad, output_grad))) + + +def check_comm(size, rank, prev_rank, next_rank, logger): + dtype = torch.float32 + device = get_current_device() + tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + tensor = torch.randn(tensor_shape, dtype=dtype, device=device) + dist.all_reduce(tensor) + grad = torch.randn(grad_shape, dtype=dtype, device=device) + dist.all_reduce(grad) + check_forward(tensor, rank, logger) + check_backward(grad, rank, logger) + check_forward_backward(tensor, grad, rank, logger) + + +def run_check(rank, world_size, port): + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + logger = get_dist_logger() + rank = gpc.get_global_rank() + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + logger.info('Rank {0}: prev rank {1}, next rank {2}'.format(rank, prev_rank, next_rank)) + logger.info('Distributed environment is initialzied.') + + check_comm(world_size, rank, prev_rank, next_rank, logger) + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +def test_p2p(): + world_size = 4 + run_func = partial(run_check, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_p2p() diff --git a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..48f7296581348b188ce0df7df313b34a5b6641fd --- /dev/null +++ b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py @@ -0,0 +1,96 @@ +# referenced from Megatron and used to testify communication + +import os +import os.path as osp +from functools import partial +from pathlib import Path + +import colossalai +import pytest +import torch +import torch.nn as nn +import torch.multiprocessing as mp +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.initialize import launch +from colossalai.utils import free_port, get_dataloader, print_rank_0 +from colossalai.testing import rerun_on_exception +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet18 + + +BATCH_SIZE = 8 + +CONFIG=dict( + NUM_MICRO_BATCHES=2, + parallel = dict( + pipeline=dict(size=2), + tensor=dict(size=1, mode=None) + ) +) + +def run_schedule(rank, world_size, port): + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # build model + model = resnet18(num_classes=10) + + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: + model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2) + elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1: + + class Flatten(nn.Module): + + def forward(self, x): + return torch.flatten(x, 1) + + model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc) + + print_rank_0('model is created') + + train_dataset = CIFAR10(root=Path(os.environ['DATA']), + download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), + ])) + + train_dataloader = get_dataloader( + dataset=train_dataset, + shuffle=True, + add_sampler=True, + batch_size=BATCH_SIZE, + pin_memory=True, + ) + + # build criterion + criterion = torch.nn.CrossEntropyLoss() + + # optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0) + + # initialize + engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader) + + # build pipeline schedule + schedule = engine.schedule + + # run schedule + data_iter = iter(train_dataloader) + schedule.forward_backward_step(engine, data_iter) + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +def test_pipeline_schedule(): + world_size = 2 + run_func = partial(run_schedule, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_pipeline_schedule() diff --git a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..b013433293cd484512a2b6569a4f3d4c23c611bd --- /dev/null +++ b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py @@ -0,0 +1,62 @@ +from functools import partial + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +from colossalai.amp.amp_type import AMP_TYPE +from colossalai.logging import get_dist_logger +from colossalai.trainer import Trainer +from colossalai.utils import MultiTimer, free_port +from tests.components_to_test.registry import non_distributed_component_funcs +from colossalai.testing import parameterize, rerun_if_address_is_in_use + +BATCH_SIZE = 4 +IMG_SIZE = 32 +NUM_EPOCHS = 200 + +CONFIG = dict(fp16=dict(mode=AMP_TYPE.TORCH)) + + +@parameterize('model_name', ['repeated_computed_layers', 'resnet18', 'nested_model']) +def run_trainer(model_name): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model = model_builder() + optimizer = optimizer_class(model.parameters(), lr=1e-3) + engine, train_dataloader, *_ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) + + logger = get_dist_logger() + logger.info("engine is built", ranks=[0]) + + timer = MultiTimer() + trainer = Trainer(engine=engine, logger=logger, timer=timer) + logger.info("trainer is built", ranks=[0]) + + logger.info("start training", ranks=[0]) + trainer.fit(train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + epochs=NUM_EPOCHS, + max_steps=3, + display_progress=True, + test_interval=5) + torch.cuda.empty_cache() + + +def run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_trainer_no_pipeline(): + world_size = 4 + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_trainer_no_pipeline() diff --git a/tests/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_trainer/test_trainer_with_pipe_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..3698526a8e6c88b3fd7ca2c49ad8c34348ba90a4 --- /dev/null +++ b/tests/test_trainer/test_trainer_with_pipe_schedule.py @@ -0,0 +1,99 @@ +import os +from functools import partial +from pathlib import Path + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.engine.schedule import PipelineSchedule +from colossalai.logging import get_dist_logger +from colossalai.trainer import Trainer +from colossalai.utils import MultiTimer, free_port, get_dataloader +from torch.optim import Adam +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet18 +from colossalai.testing import rerun_if_address_is_in_use + +BATCH_SIZE = 4 +IMG_SIZE = 32 +NUM_EPOCHS = 200 + +CONFIG = dict( + NUM_MICRO_BATCHES=2, + parallel=dict(pipeline=2), +) + + +def run_trainer_with_pipeline(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # build model + model = resnet18(num_classes=10) + + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: + model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2) + elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1: + + class Flatten(nn.Module): + + def forward(self, x): + return torch.flatten(x, 1) + + model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc) + + # build dataloaders + train_dataset = CIFAR10(root=Path(os.environ['DATA']), + download=True, + transform=transforms.Compose([ + transforms.Resize(size=(IMG_SIZE, IMG_SIZE)), + transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ])) + + train_dataloader = get_dataloader(dataset=train_dataset, + shuffle=True, + batch_size=BATCH_SIZE, + pin_memory=True, + drop_last=True) + + # build optimizer + optimizer = Adam(model.parameters(), lr=0.001) + criterion = nn.CrossEntropyLoss() + + engine, train_dataloader, *args = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) + + logger = get_dist_logger() + logger.info("engine is built", ranks=[0]) + timer = MultiTimer() + trainer = Trainer(engine=engine, logger=logger, timer=timer) + logger.info("trainer is built", ranks=[0]) + + logger.info("start training", ranks=[0]) + + trainer.fit(train_dataloader=train_dataloader, + epochs=NUM_EPOCHS, + max_steps=3, + display_progress=True, + test_interval=5) + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_trainer_with_pipeline(): + world_size = 4 + run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_trainer_with_pipeline() diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py new file mode 100644 index 0000000000000000000000000000000000000000..3ac75fb00c86d23df7e0d2f1c46899c4fc131430 --- /dev/null +++ b/tests/test_utils/test_activation_checkpointing.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pytest +import torch +import torch.nn.functional as F +from colossalai.context.parallel_mode import ParallelMode +from colossalai.context.random import add_seed, seed, set_mode, reset_seeds +from colossalai.utils.activation_checkpoint import checkpoint + + +def forward(x, weight): + out = torch.matmul(x, weight) + with seed(ParallelMode.DATA): + out_ = F.dropout(out, p=0.4, training=True) + return out_ + + +def forward_inplace_ckpt(x, weight, cpu_offload=False): + out = torch.matmul(x, weight) + bn = torch.nn.BatchNorm1d(4, affine=False) + bn = bn.to(device="cuda") + out = bn(out) + + def ckpt0(x): + return F.relu(x, inplace=True) + + out = checkpoint(ckpt0, cpu_offload, out, use_reentrant=False) + return out + + +def forward_inplace(x, weight): + out = torch.matmul(x, weight) + bn = torch.nn.BatchNorm1d(4, affine=False) + bn = bn.to(device="cuda") + out = bn(out) + out = F.relu(out, inplace=True) + return out + + +@pytest.mark.gpu +@pytest.mark.parametrize("use_reentrant", [True, False]) +@pytest.mark.parametrize("cpu_offload", [True, False]) +def test_activation_checkpointing(cpu_offload, use_reentrant): + + # as seed manager is singleton + # if we don't reset seeds here, + # other tests might affect this test + reset_seeds() + + # We put initilization here to avoid change cuda rng state below + inputs = torch.rand(2, 2, requires_grad=True, device='cuda') + weight = torch.rand(2, 4, requires_grad=True, device='cuda') + + # Get a copy of input tensors + inputs_ = torch.empty(2, 2, requires_grad=True, device='cuda') + inputs_.data.copy_(inputs.data) + weight_ = torch.empty(2, 4, requires_grad=True, device='cuda') + weight_.data.copy_(weight.data) + + add_seed(ParallelMode.GLOBAL, 1024) + add_seed(ParallelMode.DATA, 1026) + set_mode(ParallelMode.GLOBAL) + global_cuda_rng_state = torch.cuda.get_rng_state() + set_mode(ParallelMode.DATA) + data_parallel_cuda_rng_state = torch.cuda.get_rng_state() + set_mode(ParallelMode.GLOBAL) + + out = forward(inputs, weight) + loss = out.sum() + loss.backward() + + # Recover cuda rng states + set_mode(ParallelMode.GLOBAL) + torch.cuda.set_rng_state(global_cuda_rng_state) + set_mode(ParallelMode.DATA) + torch.cuda.set_rng_state(data_parallel_cuda_rng_state) + set_mode(ParallelMode.GLOBAL) + + out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=use_reentrant) + loss = out.sum() + loss.backward() + + assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match' + torch.cuda.empty_cache() + + # Extra test for use_reentrant=False + if use_reentrant == False: + # Recover cuda rng states + set_mode(ParallelMode.GLOBAL) + torch.cuda.set_rng_state(global_cuda_rng_state) + set_mode(ParallelMode.DATA) + torch.cuda.set_rng_state(data_parallel_cuda_rng_state) + set_mode(ParallelMode.GLOBAL) + + out = forward_inplace(inputs, weight) + loss = out.sum() + loss.backward() + + # Recover cuda rng states + set_mode(ParallelMode.GLOBAL) + torch.cuda.set_rng_state(global_cuda_rng_state) + set_mode(ParallelMode.DATA) + torch.cuda.set_rng_state(data_parallel_cuda_rng_state) + set_mode(ParallelMode.GLOBAL) + + out = forward_inplace_ckpt(inputs_, weight_, cpu_offload=cpu_offload) + loss = out.sum() + loss.backward() + + assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match' + torch.cuda.empty_cache() + + # as seed manager is singleton + # if we don't reset seeds here, + # other tests will fail if running together with this test + # as other tests can't overwrite the seed set by this test + reset_seeds() + + +if __name__ == "__main__": + test_activation_checkpointing(False, False) diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..96710106b8bca996270e780648d53bc27eac4085 --- /dev/null +++ b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint +from functools import partial + +import colossalai.nn as col_nn +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.utils import free_port, is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint +from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_1d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +def test_checkpoint_1d(): + world_size = 8 + run_func = partial(check_checkpoint_1d, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == "__main__": + test_checkpoint_1d() diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..a3eeaa479a47809b7934a1691127b14e405fc053 --- /dev/null +++ b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint +from functools import partial + +import colossalai.nn as col_nn +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.utils import free_port, get_current_device, is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint +from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_2d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +def test_checkpoint_2d(): + world_size = 8 + run_func = partial(check_checkpoint_2d, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == "__main__": + test_checkpoint_2d() diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py new file mode 100644 index 0000000000000000000000000000000000000000..9baddaf5a6671f39903f854e1881c9482c5c8063 --- /dev/null +++ b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint +from functools import partial + +import colossalai.nn as col_nn +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.utils import free_port, get_current_device, is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint +from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_2p5d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +def test_checkpoint_2p5d(): + world_size = 8 + run_func = partial(check_checkpoint_2p5d, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == "__main__": + test_checkpoint_2p5d() diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..481e30cdd7e53a42922c8f7f7699f4057fc0e19a --- /dev/null +++ b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint +from functools import partial + +import colossalai.nn as col_nn +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.utils import free_port, get_current_device, is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint +from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_3d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +def test_checkpoint_3d(): + world_size = 8 + run_func = partial(check_checkpoint_3d, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == "__main__": + test_checkpoint_3d() diff --git a/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py b/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..6d89fb90c574e9b06de571760b59441e75bf2b33 --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +from colossalai.utils.checkpoint_io.meta import ParamDistMeta +from colossalai.utils.checkpoint_io.utils import build_checkpoints +from torch.optim import Adam + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def test_global_model(): + model = DummyModel() + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model) + assert len(model_checkpoints) == 1 + assert len(optimizer_checkpoints) == 0 + assert meta['dist_meta'] is None + orig_state_dict = model.state_dict() + global_state_dict = model_checkpoints[0] + assert set(orig_state_dict.keys()) == set(global_state_dict.keys()) + for k, v in orig_state_dict.items(): + assert torch.equal(v, global_state_dict[k]) + + +def test_global_model_shard(): + model = DummyModel() + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model) + assert len(model_checkpoints) == 2 + assert len(optimizer_checkpoints) == 0 + assert meta['dist_meta'] is None + orig_state_dict = model.state_dict() + assert set(orig_state_dict.keys()) == set(model_checkpoints[0].keys()) | set(model_checkpoints[1].keys()) + assert len(set(model_checkpoints[0].keys()) & set(model_checkpoints[1].keys())) == 0 + for k, v in orig_state_dict.items(): + for state_dict in model_checkpoints: + if k in state_dict: + assert torch.equal(v, state_dict[k]) + + +def test_global_optimizer(): + model = DummyModel() + for p in model.parameters(): + p.grad = torch.rand_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer) + assert len(optimizer_checkpoints) == 1 + assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1} + for state in meta['paired_os'].values(): + for k, is_paired in state.items(): + if k == 'step': + assert not is_paired + else: + assert is_paired + orig_state_dict = optimizer.state_dict() + state_dict = optimizer_checkpoints[0] + for k, orig_state in orig_state_dict['state'].items(): + state = state_dict['state'][k] + for v1, v2 in zip(orig_state.values(), state.values()): + if isinstance(v2, torch.Tensor): + assert torch.equal(v1, v2) + else: + assert v2 == v2 + assert orig_state_dict['param_groups'] == state_dict['param_groups'] + + +def test_global_optimizer_shard(): + model = DummyModel() + for p in model.parameters(): + p.grad = torch.rand_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model, optimizer) + assert len(optimizer_checkpoints) == 2 + assert 'param_groups' in optimizer_checkpoints[0] and 'param_groups' not in optimizer_checkpoints[1] + orig_state_dict = optimizer.state_dict() + assert set(orig_state_dict['state'].keys()) == set(optimizer_checkpoints[0]['state'].keys()) | set( + optimizer_checkpoints[1]['state'].keys()) + assert len(set(optimizer_checkpoints[0]['state'].keys()) & set(optimizer_checkpoints[1]['state'].keys())) == 0 + for k, orig_state in orig_state_dict['state'].items(): + state = optimizer_checkpoints[0]['state'][k] if k in optimizer_checkpoints[0][ + 'state'] else optimizer_checkpoints[1]['state'][k] + for v1, v2 in zip(orig_state.values(), state.values()): + if isinstance(v2, torch.Tensor): + assert torch.equal(v1, v2) + else: + assert v1 == v2 + + assert orig_state_dict['param_groups'] == optimizer_checkpoints[0]['param_groups'] + + +def test_dist_model_optimizer(): + model = DummyModel() + for p in model.parameters(): + p.grad = torch.rand_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + dist_meta = {'fc.weight': ParamDistMeta(0, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)} + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta) + assert dist_meta == meta['dist_meta'] + assert len(model_checkpoints) == 1 + assert len(optimizer_checkpoints) == 1 + assert 'fc.weight' in model_checkpoints[0] and 'fc.bias' in model_checkpoints[0] + assert 0 in optimizer_checkpoints[0]['state'] and 1 in optimizer_checkpoints[0]['state'] + dist_meta = {'fc.weight': ParamDistMeta(1, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)} + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta) + assert dist_meta == meta['dist_meta'] + assert len(model_checkpoints) == 1 + assert len(optimizer_checkpoints) == 1 + + +if __name__ == '__main__': + test_global_model() + test_global_model_shard() + test_global_optimizer() + test_global_optimizer_shard() + test_dist_model_optimizer() diff --git a/tests/test_utils/test_checkpoint_io/test_load.py b/tests/test_utils/test_checkpoint_io/test_load.py new file mode 100644 index 0000000000000000000000000000000000000000..780c13dc534a42ce23648b6e764fb6bc31e7976a --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_load.py @@ -0,0 +1,188 @@ +from copy import deepcopy +from functools import partial +from tempfile import TemporaryDirectory +from typing import Dict + +import colossalai +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.checkpoint_io.io import load, save +from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta) +from torch import Tensor +from torch.nn import Module +from torch.optim import Adam, Optimizer + + +def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: + assert set(a.keys()) == set(b.keys()) + for k, v in a.items(): + assert torch.equal(v, b[k]) + + +def check_optim_state_dict(a: dict, b: dict, ignore_param_gruops: bool = False) -> None: + assert set(a['state'].keys()) == set(b['state'].keys()) + for k, state in a['state'].items(): + b_state = b['state'][k] + for v1, v2 in zip(state.values(), b_state.values()): + if isinstance(v1, Tensor): + assert torch.equal(v1, v2) + else: + assert v1 == v2 + if not ignore_param_gruops: + assert a['param_groups'] == b['param_groups'] + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def prepare_model_optim(shard: bool = False, zero: bool = False): + model = DummyModel() + if shard: + model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] + if zero: + dp_rank = dist.get_rank() // 2 + model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] + if dp_rank != 0: + model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) + for p in model.parameters(): + p.grad = torch.rand_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + return model, optimizer + + +def reset_model_optim(model: Module, optimizer: Optimizer, scalar: float = 0.0): + with torch.no_grad(): + for p in model.parameters(): + p.fill_(scalar) + for state in optimizer.state.values(): + for v in state.values(): + if isinstance(v, Tensor): + v.fill_(scalar) + + +def get_dist_metas(nprocs: int, zero: bool = False): + dp_world_size = nprocs // 2 + dist_metas = [] + for rank in range(nprocs): + if zero: + dist_metas.append({ + 'fc.weight': + ParamDistMeta(rank // 2, + dp_world_size, + rank % 2, + 2, + tp_shard_dims=[1], + tp_num_parts=[2], + zero_numel=10, + zero_orig_shape=[1, 10]), + 'fc.bias': + ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) + }) + else: + dist_metas.append({ + 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) + }) + return dist_metas + + +def get_redist_meta(nprocs: int): + dp_world_size = nprocs // 2 + rank_meta = { + 'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)}, + 'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)} + } + param_meta = { + 'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamRedistMeta(dp_world_size, 1) + } + return RedistMeta(rank_meta, [], param_meta) + + +@pytest.mark.parametrize('max_shard_size_gb', [80 / 1024**3, 0]) +def test_save_global_load_global(max_shard_size_gb: float): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer, max_shard_size_gb=max_shard_size_gb) + new_model, new_optimizer = prepare_model_optim() + load(dir_name, new_model, new_optimizer, max_shard_size_gb=max_shard_size_gb) + check_model_state_dict(model.state_dict(), new_model.state_dict()) + check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict()) + + +def run_dist(rank, world_size, port, func): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + func() + + +def launch_dist(fn, world_size: int): + proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) + mp.spawn(proc_fn, nprocs=world_size) + + +def save_dist(dir_name: str, zero: bool): + model, optmizer = prepare_model_optim(shard=True, zero=zero) + reset_model_optim(model, optmizer) + world_size = dist.get_world_size() + rank = dist.get_rank() + save(dir_name, model, optmizer, dist_meta=get_dist_metas(world_size, zero)[rank]) + + +def load_and_check_dist(dir_name: str): + world_size = dist.get_world_size() + model, optmizer = prepare_model_optim(shard=True) + reset_model_optim(model, optmizer) + model_state_dict = deepcopy(model.state_dict()) + optimizer_state_dict = deepcopy(optmizer.state_dict()) + reset_model_optim(model, optmizer, 1) + load(dir_name, model, optmizer, get_redist_meta(world_size), get_dist_metas(world_size)) + check_model_state_dict(model_state_dict, model.state_dict()) + check_optim_state_dict(optimizer_state_dict, optmizer.state_dict()) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_save_global_load_dist(): + model, optimizer = prepare_model_optim() + reset_model_optim(model, optimizer) + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer) + fn = partial(load_and_check_dist, dir_name) + launch_dist(fn, 4) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_save_dist_load_dist(): + with TemporaryDirectory() as dir_name: + # save tp + dp + fn = partial(save_dist, dir_name, False) + launch_dist(fn, 2) + # load tp + dp + fn = partial(load_and_check_dist, dir_name) + launch_dist(fn, 2) + with TemporaryDirectory() as dir_name: + # save tp + zero + fn = partial(save_dist, dir_name, True) + launch_dist(fn, 4) + # load tp + dp + fn = partial(load_and_check_dist, dir_name) + launch_dist(fn, 2) + launch_dist(fn, 4) + + +if __name__ == '__main__': + test_save_global_load_global(80 / 1024**3) + test_save_global_load_global(0) + test_save_global_load_dist() + test_save_dist_load_dist() diff --git a/tests/test_utils/test_checkpoint_io/test_merge.py b/tests/test_utils/test_checkpoint_io/test_merge.py new file mode 100644 index 0000000000000000000000000000000000000000..04e454dcb713545085306c5caae0deca69567f4b --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_merge.py @@ -0,0 +1,127 @@ +from colossalai.utils.checkpoint_io.meta import ParamDistMeta +from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME +from colossalai.utils.checkpoint_io.io import save, merge +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from tempfile import TemporaryDirectory +from torch.optim import Adam +from functools import partial +import torch +import os +import pytest +import colossalai +import torch.nn as nn +import torch.distributed as dist +import torch.multiprocessing as mp + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def prepare_model_optim(shard: bool = False, zero: bool = False): + model = DummyModel() + if shard: + model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] + if zero: + dp_rank = dist.get_rank() // 2 + model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] + if dp_rank != 0: + model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) + for p in model.parameters(): + p.grad = torch.ones_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + return model, optimizer + + +def test_merge_global(): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer) + with TemporaryDirectory() as output_dir: + merge(dir_name, output_dir) + assert len(os.listdir(output_dir)) == 0 + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3) + with TemporaryDirectory() as output_dir: + merge(dir_name, output_dir) + assert len(os.listdir(output_dir)) == 0 + + +def run_dist(rank, world_size, port, func): + colossalai.launch(config={'parallel': { + 'tensor': { + 'mode': '1d', + 'size': 2 + } + }}, + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') + func() + + +def run_save_dist(dir_name: str, zero: bool): + model, optmizer = prepare_model_optim(shard=True, zero=zero) + rank = dist.get_rank() + dp_world_size = dist.get_world_size() // 2 + if not zero: + dist_metas = { + 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) + } + else: + dist_metas = { + 'fc.weight': + ParamDistMeta(rank // 2, + dp_world_size, + rank % 2, + 2, + tp_shard_dims=[1], + tp_num_parts=[2], + zero_numel=10, + zero_orig_shape=[1, 10]), + 'fc.bias': + ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) + } + save(dir_name, model, optmizer, dist_meta=dist_metas) + + +@pytest.mark.dist +@pytest.mark.parametrize("zero", [False, True]) +@rerun_if_address_is_in_use() +def test_merge_tp_dp(zero: bool): + with TemporaryDirectory() as dir_name: + fn = partial(run_save_dist, dir_name, zero) + world_size = 4 + proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) + mp.spawn(proc_fn, nprocs=world_size) + with TemporaryDirectory() as output_dir: + merge(dir_name, output_dir) + assert len(os.listdir(output_dir)) == 5 + global_meta = torch.load(os.path.join(output_dir, GLOBAL_META_FILE_NAME)) + assert len(global_meta['meta']) == 1 + meta = torch.load(os.path.join(output_dir, global_meta['meta'][0])) + assert meta['dist_meta'] is None + assert len(meta['params']) == 2 + assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 + model_state_dict = torch.load(os.path.join(output_dir, meta['model'][0])) + assert len(model_state_dict) == 2 + assert model_state_dict['fc.weight'].size(1) == 20 + optimizer_state_dict = torch.load(os.path.join(output_dir, meta['optimizer'][0])) + assert len(optimizer_state_dict['state']) == 2 + assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict + assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 20 + assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 20 + + +if __name__ == '__main__': + test_merge_global() + test_merge_tp_dp(False) + test_merge_tp_dp(True) diff --git a/tests/test_utils/test_checkpoint_io/test_merge_param.py b/tests/test_utils/test_checkpoint_io/test_merge_param.py new file mode 100644 index 0000000000000000000000000000000000000000..5da2ae4fe1f8f0e0797d2b733b4fe4ecf1778a67 --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_merge_param.py @@ -0,0 +1,101 @@ +import torch +from colossalai.utils.checkpoint_io.meta import ParamDistMeta +from colossalai.utils.checkpoint_io.distributed import unflatten_zero_param, gather_tp_param, merge_param + + +def test_unflatten_zero_param_even() -> None: + dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(4)] + orig_tensor = torch.rand(4, 4) + tensors = list(orig_tensor.reshape(-1).chunk(4)) + unflattened_tensor = unflatten_zero_param(tensors, dist_metas) + assert torch.equal(orig_tensor, unflattened_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_unflatten_zero_param_uneven() -> None: + dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(1, 3)] + orig_tensor = torch.rand(4, 4) + tensors = list(orig_tensor.reshape(-1).split([13, 3])) + unflattened_tensor = unflatten_zero_param(tensors, dist_metas) + assert torch.equal(orig_tensor, unflattened_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_gather_tp_param_1d_row() -> None: + dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[0], tp_num_parts=[4]) for i in range(4)] + orig_tensor = torch.rand(4, 4) + tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)] + gathered_tensor = gather_tp_param(tensors, dist_metas) + assert torch.equal(orig_tensor, gathered_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_gather_tp_param_1d_col() -> None: + dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[1], tp_num_parts=[4]) for i in range(4)] + orig_tensor = torch.rand(4, 4) + tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)] + gathered_tensor = gather_tp_param(tensors, dist_metas) + assert torch.equal(orig_tensor, gathered_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_gather_tp_param_2d() -> None: + dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) for i in range(6)] + orig_tensor = torch.rand(4, 6) + tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] + gathered_tensor = gather_tp_param(tensors, dist_metas) + assert torch.equal(orig_tensor, gathered_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_gather_tp_param_2d_reverse() -> None: + dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) for i in range(6)] + orig_tensor = torch.rand(4, 6) + tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] + gathered_tensor = gather_tp_param(tensors, dist_metas) + assert torch.equal(orig_tensor, gathered_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_merge_param_hybrid() -> None: + dist_metas = [ + ParamDistMeta(i % 2, + 2, + i // 2, + 6, + tp_shard_dims=[1, 0], + tp_num_parts=[3, 2], + zero_numel=4, + zero_orig_shape=[2, 2]) for i in range(12) + ] + orig_tensor = torch.rand(4, 6) + tensors = [ + chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1) + for chunk in t.contiguous().reshape(-1).split([1, 3]) + ] + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_merge_param_dummy() -> None: + dist_metas = [ParamDistMeta(0, 1, 0, 1)] + orig_tensor = torch.rand(4, 6) + merged_tensor = merge_param([orig_tensor], dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +if __name__ == '__main__': + test_unflatten_zero_param_even() + test_unflatten_zero_param_uneven() + test_gather_tp_param_1d_row() + test_gather_tp_param_1d_col() + test_gather_tp_param_2d() + test_gather_tp_param_2d_reverse() + test_merge_param_hybrid() + test_merge_param_dummy() diff --git a/tests/test_utils/test_checkpoint_io/test_redist.py b/tests/test_utils/test_checkpoint_io/test_redist.py new file mode 100644 index 0000000000000000000000000000000000000000..6e76f3167e3147d74cc865212422540ca38e9e16 --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_redist.py @@ -0,0 +1,149 @@ +import os +from functools import partial +from tempfile import TemporaryDirectory + +import colossalai +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME +from colossalai.utils.checkpoint_io.io import redist, save +from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, + RedistMeta) +from torch.optim import Adam + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def prepare_model_optim(shard: bool = False, zero: bool = False): + model = DummyModel() + if shard: + model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] + if zero: + dp_rank = dist.get_rank() // 2 + model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] + if dp_rank != 0: + model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) + for p in model.parameters(): + p.grad = torch.ones_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + return model, optimizer + + +def get_dist_metas(nprocs: int, zero: bool = False): + dp_world_size = nprocs // 2 + dist_metas = [] + for rank in range(nprocs): + if zero: + dist_metas.append({ + 'fc.weight': + ParamDistMeta(rank // 2, + dp_world_size, + rank % 2, + 2, + tp_shard_dims=[1], + tp_num_parts=[2], + zero_numel=10, + zero_orig_shape=[1, 10]), + 'fc.bias': + ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) + }) + else: + dist_metas.append({ + 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) + }) + return dist_metas + + +def get_redist_meta(nprocs: int): + dp_world_size = nprocs // 2 + rank_meta = { + 'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)}, + 'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)} + } + param_meta = { + 'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamRedistMeta(dp_world_size, 1) + } + return RedistMeta(rank_meta, [], param_meta) + + +def check_checkpoint_shape(dir_name: str): + global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) + for meta_name in global_meta['meta']: + meta = torch.load(os.path.join(dir_name, meta_name)) + assert meta['dist_meta'] is not None + assert len(meta['params']) == 2 + assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 + model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) + assert len(model_state_dict) == 2 + assert model_state_dict['fc.weight'].size(1) == 10 + optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) + assert len(optimizer_state_dict['state']) == 2 + assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict + assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 10 + assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 10 + + +def test_global_to_dist(): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer) + with TemporaryDirectory() as output_dir: + redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4)) + check_checkpoint_shape(output_dir) + + +def run_dist(rank, world_size, port, func): + colossalai.launch(config={'parallel': { + 'tensor': { + 'mode': '1d', + 'size': 2 + } + }}, + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') + func() + + +def run_save_dist(dir_name: str, zero: bool): + model, optmizer = prepare_model_optim(shard=True, zero=zero) + rank = dist.get_rank() + save(dir_name, model, optmizer, dist_meta=get_dist_metas(4, zero)[rank]) + + +@pytest.mark.dist +@pytest.mark.parametrize("zero", [False, True]) +@rerun_if_address_is_in_use() +def test_dist_to_dist(zero: bool): + with TemporaryDirectory() as dir_name: + fn = partial(run_save_dist, dir_name, zero) + world_size = 4 + proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) + mp.spawn(proc_fn, nprocs=world_size) + with TemporaryDirectory() as output_dir: + redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4)) + if not zero: + assert len(os.listdir(output_dir)) == 0 + else: + check_checkpoint_shape(output_dir) + + +if __name__ == '__main__': + test_global_to_dist() + test_dist_to_dist(False) + test_dist_to_dist(True) diff --git a/tests/test_utils/test_checkpoint_io/test_save.py b/tests/test_utils/test_checkpoint_io/test_save.py new file mode 100644 index 0000000000000000000000000000000000000000..5ff9d0aa22177a3cde101eae7d6ef416798940db --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_save.py @@ -0,0 +1,147 @@ +import os +from functools import partial +from tempfile import TemporaryDirectory +from typing import Dict + +import colossalai +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.checkpoint_io.constant import (GLOBAL_META_FILE_NAME, META_CKPT_FILE_NAME, MODEL_CKPT_FILE_NAME, + OTHER_CKPT_FILE_NAME) +from colossalai.utils.checkpoint_io.io import save +from colossalai.utils.checkpoint_io.meta import ParamDistMeta +from torch import Tensor +from torch.optim import Adam + + +def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: + assert set(a.keys()) == set(b.keys()) + for k, v in a.items(): + assert torch.equal(v, b[k]) + + +def check_optim_state_dict(a: dict, b: dict, ignore_param_gruops: bool = False) -> None: + assert set(a['state'].keys()) == set(b['state'].keys()) + for k, state in a['state'].items(): + b_state = b['state'][k] + for v1, v2 in zip(state.values(), b_state.values()): + if isinstance(v1, Tensor): + assert torch.equal(v1, v2) + else: + assert v1 == v2 + if not ignore_param_gruops: + assert a['param_groups'] == b['param_groups'] + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def prepare_model_optim(): + model = DummyModel() + for p in model.parameters(): + p.grad = torch.ones_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + return model, optimizer + + +def test_overwrite(): + model = DummyModel() + with TemporaryDirectory() as dir_name: + with open(os.path.join(dir_name, MODEL_CKPT_FILE_NAME.replace('.bin', '-shard0.bin')), 'a') as f: + pass + with pytest.raises(RuntimeError, match=r'Save error: Checkpoint ".+" exists\. \(overwrite = False\)'): + save(dir_name, model) + + +def test_save_global(): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer) + assert len(os.listdir(dir_name)) == 5 + global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) + assert len(global_meta['meta']) == 1 and global_meta['meta'][0] == META_CKPT_FILE_NAME + meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME)) + assert len(meta['model']) == 1 + assert len(meta['optimizer']) == 1 + model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) + check_model_state_dict(model.state_dict(), model_state_dict) + optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) + check_optim_state_dict(optimizer.state_dict(), optimizer_state_dict) + other_state_dict = torch.load(os.path.join(dir_name, OTHER_CKPT_FILE_NAME)) + assert len(other_state_dict) == 0 + + +def test_save_global_shard(): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3) + assert len(os.listdir(dir_name)) == 7 + meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME)) + assert len(meta['model']) == 2 and len(meta['optimizer']) == 2 + model_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['model']] + assert len(set(model_state_dicts[0].keys()) & set(model_state_dicts[1].keys())) == 0 + check_model_state_dict(model.state_dict(), {**model_state_dicts[0], **model_state_dicts[1]}) + optimizer_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['optimizer']] + assert len(set(optimizer_state_dicts[0]['state'].keys()) & set(optimizer_state_dicts[1]['state'].keys())) == 0 + assert 'param_groups' in optimizer_state_dicts[0] and 'param_groups' not in optimizer_state_dicts[1] + check_optim_state_dict( + optimizer.state_dict(), { + 'state': { + **optimizer_state_dicts[0]['state'], + **optimizer_state_dicts[1]['state'] + }, + 'param_groups': optimizer_state_dicts[0]['param_groups'] + }) + + +def run_dist(rank, world_size, port, func): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + func() + + +def run_save_dist(dir_name): + model, optmizer = prepare_model_optim() + dist_metas = { + 'fc.weight': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1), + 'fc.bias': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1) + } + save(dir_name, model, optmizer, dist_meta=dist_metas) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_save_dist(): + with TemporaryDirectory() as dir_name: + fn = partial(run_save_dist, dir_name) + world_size = 2 + proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) + mp.spawn(proc_fn, nprocs=world_size) + assert len(os.listdir(dir_name)) == 8 + global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) + assert len(global_meta['meta']) == 2 + for rank, meta_name in enumerate(global_meta['meta']): + meta = torch.load(os.path.join(dir_name, meta_name)) + assert meta.get('dist_meta', None) is not None + assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 + model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) + assert len(model_state_dict) == 2 + optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) + assert len(optimizer_state_dict['state']) == 2 + assert 'param_groups' in optimizer_state_dict + + +if __name__ == '__main__': + test_overwrite() + test_save_global() + test_save_global_shard() + test_save_dist() diff --git a/tests/test_utils/test_checkpoint_io/test_unmerge_param.py b/tests/test_utils/test_checkpoint_io/test_unmerge_param.py new file mode 100644 index 0000000000000000000000000000000000000000..8b83caa12359ff8c404e65f96701a6c390f9792d --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_unmerge_param.py @@ -0,0 +1,137 @@ +import torch +from colossalai.utils.checkpoint_io.meta import ParamRedistMeta +from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param + + +def test_flatten_zero_param_even() -> None: + redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=0, zero_offsets=[0, 4, 8, 12]) + orig_tensor = torch.rand(4, 4) + tensors = list(orig_tensor.reshape(-1).chunk(4)) + flat_tensors = flatten_zero_param(orig_tensor, redist_meta) + assert len(tensors) == len(flat_tensors) + for t, st in zip(tensors, flat_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(unmerged_tensors) == 1 + unmerged_tensors = unmerged_tensors[0] + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert torch.equal(t, tl) + + +def test_flatten_zero_param_uneven() -> None: + redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=1, zero_offsets=[0, 13]) + orig_tensor = torch.rand(4, 4) + tensors = list(orig_tensor.reshape(-1).split([13, 3])) + flat_tensors = flatten_zero_param(orig_tensor, redist_meta) + assert flat_tensors[0].size(0) == 0 and flat_tensors[-1].size(0) == 0 + flat_tensors = flat_tensors[1:-1] + assert len(tensors) == len(flat_tensors) + for t, st in zip(tensors, flat_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(unmerged_tensors) == 1 + unmerged_tensors = unmerged_tensors[0] + assert unmerged_tensors[0].size(0) == 0 and unmerged_tensors[-1].size(0) == 0 + unmerged_tensors = unmerged_tensors[1:-1] + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert torch.equal(t, tl) + + +def test_split_tp_param_1d_row() -> None: + redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[0], tp_num_parts=[4]) + orig_tensor = torch.rand(4, 4) + tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)] + split_tensors = split_tp_param(orig_tensor, redist_meta) + assert len(tensors) == len(split_tensors) + for t, st in zip(tensors, split_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert len(tl) == 1 + assert torch.equal(t, tl[0]) + + +def test_split_tp_param_1d_col() -> None: + redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[1], tp_num_parts=[4]) + orig_tensor = torch.rand(4, 4) + tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)] + split_tensors = split_tp_param(orig_tensor, redist_meta) + assert len(tensors) == len(split_tensors) + for t, st in zip(tensors, split_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert len(tl) == 1 + assert torch.equal(t, tl[0]) + + +def test_split_tp_param_2d() -> None: + redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) + orig_tensor = torch.rand(4, 6) + tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] + split_tensors = split_tp_param(orig_tensor, redist_meta) + assert len(tensors) == len(split_tensors) + for t, st in zip(tensors, split_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert len(tl) == 1 + assert torch.equal(t, tl[0]) + + +def test_split_tp_param_2d_reverse() -> None: + redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) + orig_tensor = torch.rand(4, 6) + tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] + split_tensors = split_tp_param(orig_tensor, redist_meta) + assert len(tensors) == len(split_tensors) + for t, st in zip(tensors, split_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert len(tl) == 1 + assert torch.equal(t, tl[0]) + + +def test_unmerge_param_hybrid() -> None: + redist_meta = ParamRedistMeta(2, + 6, + tp_shard_dims=[1, 0], + tp_num_parts=[3, 2], + zero_start_dp_rank=0, + zero_offsets=[0, 1]) + orig_tensor = torch.rand(4, 6) + tensors = [ + chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1) + for chunk in t.contiguous().reshape(-1).split([1, 3]) + ] + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(unmerged_tensors) == 6 and len(unmerged_tensors[0]) == 2 + for tp_rank in range(6): + for dp_rank in range(2): + assert torch.equal(tensors[tp_rank * 2 + dp_rank], unmerged_tensors[tp_rank][dp_rank]) + + +def test_unmerge_param_dummy() -> None: + redist_meta = ParamRedistMeta(1, 1) + orig_tensor = torch.rand(4, 6) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(unmerged_tensors) == 1 and len(unmerged_tensors[0]) == 1 + assert torch.equal(orig_tensor, unmerged_tensors[0][0]) + + +if __name__ == '__main__': + test_flatten_zero_param_even() + test_flatten_zero_param_uneven() + test_split_tp_param_1d_row() + test_split_tp_param_1d_col() + test_split_tp_param_2d() + test_split_tp_param_2d_reverse() + test_unmerge_param_hybrid() + test_unmerge_param_dummy() diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..a5ea75fffc3629e1a160948e0cb12317bbbccbdd --- /dev/null +++ b/tests/test_utils/test_colo_checkpoint.py @@ -0,0 +1,217 @@ +import os, shutil +import torch +import pytest +from copy import deepcopy +from functools import partial + +import torch.multiprocessing as mp +import torch.distributed as dist + +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import MultiplicativeLR +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils.cuda import get_current_device +from colossalai.utils import free_port +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup +from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint +from colossalai.nn.optimizer import ColossalaiOptimizer + +from tests.components_to_test.registry import non_distributed_component_funcs + + +def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup): + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) + + +def init_1d_col_linear(weight, pg): + spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) + + +def init_1d_row_embedding(weight, pg): + spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) + + +def init_1d_col_embedding(weight, pg): + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) + + +def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup): + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + for name, p in model.named_parameters(): + if not isinstance(p, ColoTensor): + continue + if 'embed' in name and 'weight' in name: + init_1d_col_embedding(p, pg) + if 'proj1' in name and ('weight' in name or 'bias' in name): + init_1d_col_linear(p, pg) + if 'proj2' in name and 'weight' in name: + init_1d_row_linear(p, pg) + if 'classifier' in name and ('weight' in name or 'bias' in name): + init_1d_col_linear(p, pg) + + +def check_param_equal(model, torch_model): + for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()): + assert torch.all(p.data == tp.data), "{} went wrong.\n {} vs {}\n{}".format(n, p, tp, p.shape) + + +def remove(path): + """ param could either be relative or absolute. """ + if os.path.isfile(path) or os.path.islink(path): + os.remove(path) + elif os.path.isdir(path): + shutil.rmtree(path) + else: + raise ValueError("file {} is not a file or dir.".format(path)) + + +def compare_optims(optim1, optim2): + state1 = optim1.state_dict()['state'] + state2 = optim2.state_dict()['state'] + for k, p1 in state1.items(): + if k not in state2: + continue + p2 = state2[k] + for n, t1 in p1.items(): + if n not in p2: + continue + t2 = p2[n] + if isinstance(t1, ColoTensor): + assert isinstance(t2, ColoTensor) + assert torch.allclose(t1, t2, rtol=0, atol=0) + + +def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + # set_seed(1) + with ColoInitContext(device=get_current_device()): + model = model_builder(checkpoint=True) + + if use_mp_reload: + if 'bert' == model_name: + for name, p in model.named_parameters(): + if not isinstance(p, ColoTensor): + continue + # num_class = type_vocab_size = 2 | (8, 2) + if 'classifier' in name and 'weight' in name: + init_1d_row_linear(p, pg) + # num_class = vocab_size = 30524 | (30524, 8) + elif 'word_embeddings' in name and 'weight' in name: + init_1d_row_embedding(p, pg) + # num_class = seq_len = 512 | (512, 8) + elif 'position_embeddings' in name and 'weight' in name: + init_1d_row_embedding(p, pg) + # num_class = type_vocab_size = 2 | (2, 8) + elif 'token_type_embeddings' in name and 'weight' in name: + init_1d_col_embedding(p, pg) + elif p.process_group.tp_world_size() == 1: + p.set_process_group(pg) + elif "simple_net" == model_name: + init_spec_func(model, pg) + + model_reload = deepcopy(model) + model = model.cuda() + model.eval() + + model_reload = model_reload.cuda() + model_reload.eval() + + opt_class = torch.optim.Adam + colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1)) + colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1)) + + for i, (data, label) in enumerate(train_dataloader): + + # Zero grad + colo_optimizer.zero_grad() + colo_optimizer_reload.zero_grad() + + data = data.to(get_current_device()) + label = label.to(get_current_device()) + + dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group()) + dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group()) + + # Bcast rank0 data to all processes + if criterion: + output = model(data) + output_reload = model_reload(data) + loss = criterion(output, label) + loss_reload = criterion(output_reload, label) + else: + loss = model(data, label) + loss_reload = model_reload(data, label) + + loss.backward() + loss_reload.backward() + + colo_optimizer.step() + colo_optimizer_reload.step() + + if i > 2: + break + + if not os.path.isdir('./checkpoint') and rank == 0: + os.mkdir('./checkpoint') + dist.barrier() + + save_checkpoint('./checkpoint', 0, model, colo_optimizer, None) + load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None) + + check_param_equal(model, model_reload) + compare_optims(colo_optimizer, colo_optimizer_reload) + + if rank == 0: + remove('./checkpoint') + dist.barrier() + + +def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(tp_degree=world_size) + + # the data loader of BERT is in DDP mode, causing the input data is not replicated in the TP context + for model_name in ['bert']: + _run_checkpoint(model_name, + init_1d_row_for_linear_weight_spec, + use_ddp, + use_mp_reload, + test_scheduler=test_scheduler, + pg=pg) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@pytest.mark.parametrize('use_ddp', [False]) +@pytest.mark.parametrize('use_mp_reload', [True, False]) +# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda']) +@rerun_if_address_is_in_use() +def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None): + run_func = partial(run_dist, + world_size=world_size, + port=free_port(), + use_ddp=use_ddp, + use_mp_reload=use_mp_reload, + test_scheduler=test_scheduler) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine") diff --git a/tests/test_utils/test_commons.py b/tests/test_utils/test_commons.py new file mode 100644 index 0000000000000000000000000000000000000000..0ecb7446c788c567a739e99a039c699899e4a907 --- /dev/null +++ b/tests/test_utils/test_commons.py @@ -0,0 +1,44 @@ +from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.zero.sharded_param import ShardedTensor +from colossalai.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline +import colossalai + +import torch + +import torch.multiprocessing as mp + + +def run_tensor_move(rank): + colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') + + src_t = torch.ones(2, 3).cuda() + tgt_t = torch.zeros(2, 3) + + colo_model_data_tensor_move(src_t, tgt_t) + assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" + + src_t = torch.ones(2, 3) + tgt_t = torch.zeros(2, 3).cuda().half() + colo_model_data_tensor_move(src_t, tgt_t) + # the src_t has been removed + assert (src_t.numel() == 0) + assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" + + src_t = ShardedTensor(torch.ones(2, 3)) + tgt_t = ShardedTensor(torch.zeros(2, 3).cuda().half()) + colo_model_data_tensor_move(src_t, tgt_t) + assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" + + assert (tgt_t.device.type == 'cuda') + colo_model_data_tensor_move_inline(tgt_t, torch.device('cpu')) + assert (tgt_t.device.type == 'cpu') + + +@rerun_if_address_is_in_use() +def test_tensor_move(): + mp.spawn(run_tensor_move, nprocs=1) + + +if __name__ == '__main__': + test_tensor_move() diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..58e3b21d97eb3d7a56a707ee5a50f9539e80df9c --- /dev/null +++ b/tests/test_utils/test_flash_attention.py @@ -0,0 +1,146 @@ +import pytest +import torch +from einops import rearrange + +from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN, HAS_TRITON + +if HAS_FLASH_ATTN: + from colossalai.kernel.cuda_native.flash_attention import ( + MaskedFlashAttention, + flash_attention_q_k_v, + flash_attention_q_kv, + flash_attention_qkv, + ) + +if HAS_TRITON: + from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention + +if HAS_MEM_EFF_ATTN: + from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention + + +def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale): + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + for z in range(Z): + for h in range(H): + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + return ref_out + + +@pytest.mark.skipif(HAS_TRITON == False, reason="triton is not available") +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)]) +def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + sm_scale = 0.3 + dout = torch.randn_like(q) + + ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # triton implementation + tri_out = triton_flash_attention(q, k, v, sm_scale) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + assert torch.allclose(ref_out, tri_out, atol=1e-3) + assert torch.allclose(ref_dv, tri_dv, atol=1e-3) + assert torch.allclose(ref_dk, tri_dk, atol=1e-3) + assert torch.allclose(ref_dq, tri_dq, atol=1e-3) + + +@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available") +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)]) +def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + torch.manual_seed(20) + q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + sm_scale = 0.3 + dout = torch.randn_like(q) + + # reference implementation + ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # flash implementation + q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v]) + dout = rearrange(dout, 'z h n d -> (z n) h d').detach() + for i in range(3): + if i == 0: + tri_out = flash_attention_q_k_v(q, k, v, sm_scale, Z, N_CTX, N_CTX, causal=True) + elif i == 1: + kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1) + tri_out = flash_attention_q_kv(q, kv, sm_scale, Z, N_CTX, N_CTX, causal=True) + else: + qkv = torch.cat((q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)), dim=1) + tri_out = flash_attention_qkv(qkv, sm_scale, Z, N_CTX, causal=True) + + tri_out.backward(dout, retain_graph=True) + + if i == 0: + tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout) + tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), + (tri_out, tri_dq, tri_dk, tri_dv)) + elif i == 1: + tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout) + tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1) + tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), + (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1))) + else: + tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout) + tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1) + tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), + (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1))) + + # compare + assert torch.allclose(ref_out, tri_out, atol=1e-3) + assert torch.allclose(ref_dv, tri_dv, atol=1e-3) + assert torch.allclose(ref_dk, tri_dk, atol=1e-3) + assert torch.allclose(ref_dq, tri_dq, atol=1e-3) + + +@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available") +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)]) +def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + attn = MaskedFlashAttention(N_CTX, D_HEAD, 0.1) + + qkv = torch.randn((Z, H, 3 * N_CTX * D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + attention_mask = torch.randint(2, (Z, H)).cuda().bool() + + out = attn(qkv, attention_mask) + + dout = torch.rand_like(out) + out.backward(dout) + + +@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 8, 4, 16)]) +def test_memory_efficient_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + attn = MemoryEfficientAttention(N_CTX * D_HEAD, N_CTX, 0.1) + + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + + out = attn(q, k, v, attention_mask=LowerTriangularMask()) + + dout = torch.rand_like(out) + out.backward(dout) + + +if __name__ == '__main__': + test_flash_attention(3, 4, 2, 16) diff --git a/tests/test_utils/test_lazy_init_ctx.py b/tests/test_utils/test_lazy_init_ctx.py new file mode 100644 index 0000000000000000000000000000000000000000..97efb3367490e772f351939bec7949fd86ad4da3 --- /dev/null +++ b/tests/test_utils/test_lazy_init_ctx.py @@ -0,0 +1,51 @@ +import torch +from colossalai.utils.model.lazy_init_context import LazyInitContext +from torchvision.models import resnet34 +import random +import numpy as np + +MANUAL_SEED = 0 +random.seed(MANUAL_SEED) +np.random.seed(MANUAL_SEED) +torch.manual_seed(MANUAL_SEED) + + +def test_lazy_init_with_meta(): + ctx = LazyInitContext(to_meta=True) + with ctx: + model = resnet34(num_classes=10) + + for param in model.parameters(): + assert param.is_meta + for buffer in model.buffers(): + assert buffer.is_meta + + ctx.lazy_init_parameters(model) + + for name, param in model.named_parameters(): + assert not param.is_meta, name + + for buffer in model.buffers(): + assert not buffer.is_meta + + +def test_lazy_init_without_meta(): + ctx = LazyInitContext(to_meta=False) + with ctx: + model = resnet34(num_classes=10) + + for param in model.parameters(): + assert not param.is_meta + for buffer in model.buffers(): + assert not buffer.is_meta + + conv1_weight_before_init = model.conv1.weight.clone() + ctx.lazy_init_parameters(model) + conv1_weight_after_init = model.conv1.weight.clone() + + assert not torch.allclose(conv1_weight_after_init, conv1_weight_before_init) + + +if __name__ == '__main__': + test_lazy_init_with_meta() + test_lazy_init_without_meta() diff --git a/tests/test_utils/test_memory.py b/tests/test_utils/test_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..46a5aeba505b733b669f409b514c8c72b7defdb6 --- /dev/null +++ b/tests/test_utils/test_memory.py @@ -0,0 +1,32 @@ +import pytest + +import colossalai +from colossalai.utils.cuda import get_current_device +from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity +from colossalai.utils import free_port + +from functools import partial +import torch.multiprocessing as mp + + +def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): + frac1 = colo_device_memory_capacity(get_current_device()) + colo_set_process_memory_fraction(0.5) + frac2 = colo_device_memory_capacity(get_current_device()) + assert frac2 * 2 == frac1 + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [3, 4]) +def test_memory_utils(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_memory_utils(world_size=2) diff --git a/tests/test_utils/test_norm_gradient_clipping.py b/tests/test_utils/test_norm_gradient_clipping.py new file mode 100644 index 0000000000000000000000000000000000000000..259286663033a91f74213e232a385a8742677298 --- /dev/null +++ b/tests/test_utils/test_norm_gradient_clipping.py @@ -0,0 +1,79 @@ +from colossalai.tensor import distspec, ColoTensorSpec, ProcessGroup +from colossalai.tensor.colo_parameter import ColoParameter +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +from colossalai.logging import disable_existing_loggers +from colossalai.utils import free_port, get_current_device +from torch.nn.utils import clip_grad_norm_ +from functools import partial +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils.common import clip_grad_norm +from torch.nn.parameter import Parameter + + +def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8): + return abs(num - other) <= atol + rtol * other + + +def shard_param(p: ColoParameter) -> None: + pg = p.get_process_group() + p._redistribute(distspec.ShardSpec([0], [pg.tp_world_size()])) + p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach() + + +def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None: + pg = colo_p.get_process_group() + if p.shape != colo_p.shape: + grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()] + else: + grad = p.grad + assert torch.allclose(grad, colo_p.grad), f'diff: {torch.abs(grad - colo_p.grad)}' + + +@parameterize('dtype', [torch.float]) +@parameterize('device', ['mixed', 'cuda', 'cpu']) +@parameterize('norm_type', [2.0, 3.0, float('inf')]) +def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float): + print(f'{world_size}, {dtype}, {device}, {norm_type}') + cuda_device = get_current_device() + devices = [cuda_device] * 4 + if device == 'cpu': + devices = [torch.device('cpu')] * 4 + elif device == 'mixed': + devices = [cuda_device] * 2 + [torch.device('cpu')] * 2 + pg = ProcessGroup(tp_degree=world_size) + params = [Parameter(torch.empty(4, 4, dtype=dtype, device=devices[i])) for i in range(4)] + colo_params = [ + ColoParameter(torch.empty(4, 4, dtype=dtype, device=devices[i]), spec=ColoTensorSpec(pg)) for i in range(4) + ] + for p, colo_p in zip(params, colo_params): + grad = torch.rand_like(p) + p.grad = grad + colo_p.grad = grad.clone().detach() + shard_param(colo_params[0]) + shard_param(colo_params[2]) + torch_norm = clip_grad_norm_(params, 1.0, norm_type=norm_type) + colo_norm = clip_grad_norm(colo_params, 1.0, norm_type=norm_type) + assert close(torch_norm, colo_norm), f'diff: {abs(torch_norm-colo_norm)}' + for p, colo_p in zip(params, colo_params): + check_grad_equal(p, colo_p) + + +def run_dist(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_grad_clip_norm(world_size=world_size) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_zero_clip_grad(world_size: int): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_zero_clip_grad(2) diff --git a/tests/test_utils/test_zero_gradient_clippling.py b/tests/test_utils/test_zero_gradient_clippling.py new file mode 100644 index 0000000000000000000000000000000000000000..8bdae88464b17f624d66f470241aecfa57b3a8e7 --- /dev/null +++ b/tests/test_utils/test_zero_gradient_clippling.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import copy + +import colossalai +from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from colossalai.logging import disable_existing_loggers +from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy +from functools import partial +from colossalai.testing import parameterize, rerun_if_address_is_in_use + + +def checkpoint_wrapper(module, enable=True): + if enable: + module.forward = partial(checkpoint, module.forward, False) + return module + + +class Net(nn.Module): + + def __init__(self, checkpoint=False) -> None: + super().__init__() + self.fc1 = nn.Linear(5, 5) + self.fc2 = nn.Linear(5, 5) + self.fc3 = nn.Linear(5, 1) + if checkpoint: + self.fc1 = checkpoint_wrapper(self.fc1) + self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3] + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0): + model.train() + optimizer.zero_grad() + with torch.cuda.amp.autocast(enabled=enable_autocast): + y = model(x) + loss = y.sum() + loss = loss.float() + loss.backward() + clip_grad(model, norm_type) + optimizer.step() + + +def clip_grad(model, norm_type): + if isinstance(model, DDP): + clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type) + else: + clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type) + + +def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: + if loose: + return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3) + return torch.allclose(tensor_a, tensor_b) + + +def check_grads(model, zero_model, loose=False): + rank = dist.get_rank() + for p, zero_p in zip(model.parameters(), zero_model.parameters()): + zero_grad = zero_p.grad.clone().to(p.device) + chunks = torch.flatten(p.grad).chunk(4) + if rank >= len(chunks): + continue + grad = chunks[rank] + if zero_p.zero_shard_padding > 0: + zero_grad = zero_grad[:-zero_p.zero_shard_padding] + assert grad.dtype == zero_grad.dtype + assert allclose(grad, zero_grad, loose=loose) + + +def check_params(model, zero_model, loose=False): + rank = dist.get_rank() + for p, zero_p in zip(model.parameters(), zero_model.parameters()): + zero_shard_padding = zero_p.zero_shard_padding + zero_p = zero_p.clone().to(p.device) + chunks = torch.flatten(p).chunk(4) + if rank >= len(chunks): + continue + p = chunks[rank] + if zero_shard_padding > 0: + zero_p = zero_p[:-zero_shard_padding] + assert p.dtype == zero_p.dtype + assert allclose(p, zero_p, loose=loose) + + +def run_dist(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_zero_clip_grad(): + world_size = 4 + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_zero_clip_grad() diff --git a/tests/test_zero/common.py b/tests/test_zero/common.py new file mode 100644 index 0000000000000000000000000000000000000000..bc6cd75a6a609ae8ab50875a2854a6e23c994caa --- /dev/null +++ b/tests/test_zero/common.py @@ -0,0 +1,139 @@ +from functools import partial + +import torch +import torch.distributed as dist +from colossalai.logging import get_dist_logger +from colossalai.utils import checkpoint +from colossalai.zero.shard_utils import TensorShardStrategy +from colossalai.zero.sharded_model import ShardedModelV2 + +LOGGER = get_dist_logger('zero_test') + +MP_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), parallel=dict(pipeline=dict(size=1), tensor=dict(size=2, mode=None))) + +_ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25, + fp32_reduce_scatter=False, + tensor_placement_policy='cuda', + gradient_predivide_factor=1.0, + shard_strategy=TensorShardStrategy(), + reuse_fp16_shard=False) + +_ZERO_OPTIMIZER_CONFIG = dict(initial_scale=2**5, + min_scale=1, + growth_factor=2, + backoff_factor=0.5, + growth_interval=1000, + hysteresis=2, + max_scale=2**32) + +ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), + zero=dict( + model_config=_ZERO_MODEL_CONFIG, + optimizer_config=_ZERO_OPTIMIZER_CONFIG, + ), + parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None))) + +CONFIG = dict(fp16=dict(mode=None,), + zero=dict(level=3, + verbose=False, + offload_optimizer_config=dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False), + offload_param_config=dict(device='cpu', + pin_memory=True, + buffer_count=5, + buffer_size=1e8, + max_in_cpu=1e9)), + parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None))) + + +def run_fwd_bwd(model, data, label, criterion, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + if isinstance(model, ShardedModelV2): + model.backward(loss) + else: + loss.backward() + + +def checkpoint_wrapper(module, enable=True): + if enable: + module.forward = partial(checkpoint, module.forward) + return module + + +def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: + if loose: + return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3) + return torch.allclose(tensor_a, tensor_b) + + +def check_grads(model, zero_model, loose=False): + for p, zero_p in zip(model.parameters(), zero_model.parameters()): + zero_grad = zero_p.grad.clone().to(p.device) + grad = p.grad.float() + assert grad.dtype == zero_grad.dtype + assert allclose(grad, zero_grad, loose=loose) + + +def check_params(model, zero_model, loose=False): + for p, zero_p in zip(model.parameters(), zero_model.parameters()): + zero_p = zero_p.clone().to(p.device) + # assert p.dtype == zero_p.dtype + assert allclose(p.float(), zero_p.float(), loose=loose), f"diff {p.float() - zero_p.float()}" + + +def check_grads_padding(model, zero_model, loose=False): + rank = dist.get_rank() + for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()): + # zero_grad = zero_p.grad.clone().to(p.device) + if zero_p.colo_attr.is_replicated: + zero_grad = zero_p.colo_attr.grad_payload.clone().to(p.device) + chunks = torch.flatten(p.grad).chunk(dist.get_world_size()) + if rank >= len(chunks): + continue + grad = chunks[rank].float() + if zero_grad.size(0) > grad.size(0): + zero_grad = zero_grad[:grad.size(0)] + else: + zero_grad = zero_p.colo_attr.grad_payload + grad = p.grad.to(zero_grad.dtype) + + assert grad.dtype == zero_grad.dtype + assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}' + + +def check_params_padding(model, zero_model, loose=False): + rank = dist.get_rank() + for p, zero_p in zip(model.parameters(), zero_model.parameters()): + zero_p = zero_p.clone().to(p.device) + chunks = torch.flatten(p).chunk(dist.get_world_size()) + if rank >= len(chunks): + continue + p = chunks[rank] + if zero_p.size(0) > p.size(0): + zero_p = zero_p[:p.size(0)] + assert p.dtype == zero_p.dtype + assert allclose(p, zero_p, loose=loose) + + +def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False): + rank = dist.get_rank() + for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()): + if zero_p.colo_attr.param_is_sharded: + zero_p = zero_p.colo_attr.data_payload.to(p.device).float() + chunks = torch.flatten(p).chunk(dist.get_world_size()) + if rank >= len(chunks): + continue + p = chunks[rank].float() + if zero_p.size(0) > p.size(0): + zero_p = zero_p[:p.size(0)] + else: + zero_p = zero_p.colo_attr.data_payload.to(p.device) + + assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype) + assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}' diff --git a/tests/test_zero/low_level_zero/test_grad_acc.py b/tests/test_zero/low_level_zero/test_grad_acc.py new file mode 100644 index 0000000000000000000000000000000000000000..c23b3a3e8fd861ffbf3da73883950ea6d4c24ebe --- /dev/null +++ b/tests/test_zero/low_level_zero/test_grad_acc.py @@ -0,0 +1,167 @@ +import copy +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.testing.random import seed_all +from colossalai.utils import free_port +from colossalai.zero import LowLevelZeroOptimizer + + +class TestModel(nn.Module): + + def __init__(self): + super(TestModel, self).__init__() + self.linear1 = nn.Linear(128, 256) + self.linear2 = nn.Linear(256, 512) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +def exam_zero_1_2_grad_acc(): + local_rank = torch.distributed.get_rank() + seed_all(2009) + + # create model + zero1_model = TestModel().cuda() + zero2_model = copy.deepcopy(zero1_model) + + # create optimizer + zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) + zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) + zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer, + overlap_communication=True, + initial_scale=32, + clip_grad_norm=1.0, + verbose=True) + zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, + overlap_communication=True, + partition_grad=True, + initial_scale=32, + clip_grad_norm=1.0) + # create data + seed_all(2021 + local_rank) + input_data1 = torch.randn(32, 128).cuda() + input_data2 = torch.randn(32, 128).cuda() + + def fwd_bwd_func(number, cur_data): + # zero-dp forward + zero1_output = zero1_model(cur_data) + zero2_output = zero2_model(cur_data) + assert torch.equal(zero1_output, zero2_output) + + # zero-dp backward + zero1_optimizer.backward(zero1_output.sum().float()) + zero2_optimizer.backward(zero2_output.sum().float()) + + for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): + if z2p.grad is not None: + # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad))) + assert torch.equal(z1p.grad, z2p.grad) + + zero1_optimizer.sync_grad() + zero2_optimizer.sync_grad() + + fwd_bwd_func(0, input_data1) + fwd_bwd_func(1, input_data2) + + # step + zero1_optimizer.step() + zero2_optimizer.step() + + # check updated param + for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): + assert torch.equal(z1p.data, z2p.data) + + +def exam_zero_1_grad_acc(): + local_rank = torch.distributed.get_rank() + grad_scale = 32 + seed_all(2008) + + # create models + zero_model = TestModel() + torch_model = copy.deepcopy(zero_model) + + zero_model = zero_model.cuda() + torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0) + + # create optimizer + zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1) + + # we only test stage 1 here + # in `check_sharded_param_consistency.py`, we will test whether + # level 1 and 2 will produce exactly the same results + zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, + overlap_communication=False, + initial_scale=grad_scale, + reduce_bucket_size=262144, + clip_grad_norm=1.0) + + torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) + + # create data + seed_all(2022 + local_rank) + input_data1 = torch.randn(32, 128).cuda() + input_data2 = torch.randn(32, 128).cuda() + + def fwd_bwd_func(number, cur_data, check_flag): + # zero-dp forward + zero_output = zero_model(cur_data) + + # torch-ddp forward + torch_output = torch_model(cur_data) + assert torch.equal(zero_output, torch_output) + + # zero-dp backward + zero_optimizer.backward(zero_output.sum().float()) + # torch-ddp backward + torch_output.sum().backward() + + if check_flag: + # check grad + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + unscale_grad = z1p.grad / grad_scale + # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad))) + assert torch.equal(p.grad, unscale_grad) + + zero_optimizer.sync_grad() + + fwd_bwd_func(0, input_data1, True) + fwd_bwd_func(1, input_data2, False) + + zero_optimizer.step() + torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0) + torch_optimizer.step() + + # check updated param + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + # print(n, p.shape, torch.max(p.data), torch.max(z1p.data), torch.max(torch.abs(p.data - z1p.data))) + assert_close(p.data, z1p.data) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + + exam_zero_1_grad_acc() + # exam_zero_1_2_grad_acc() + + +@pytest.mark.dist +def test_grad_accumulation(): + world_size = 2 + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_grad_accumulation() diff --git a/tests/test_zero/low_level_zero/test_zero1_2.py b/tests/test_zero/low_level_zero/test_zero1_2.py new file mode 100644 index 0000000000000000000000000000000000000000..b02d3a6a448657be3b53af02205b69f4ea7c9018 --- /dev/null +++ b/tests/test_zero/low_level_zero/test_zero1_2.py @@ -0,0 +1,186 @@ +import copy +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.testing.random import seed_all +from colossalai.utils import free_port +from colossalai.zero import LowLevelZeroOptimizer + + +class TestModel(nn.Module): + + def __init__(self): + super(TestModel, self).__init__() + self.linear1 = nn.Linear(128, 256) + self.linear2 = nn.Linear(256, 512) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +def half_close(a, b, loose=False): + rtol = None + atol = None + if loose: + rtol = 5e-2 + atol = 5e-4 + + a = a.detach().half() + b = b.detach().half() + + assert_close(a, b, rtol=rtol, atol=atol) + + +def exam_zero_1_2(): + """ + In this test, we want to test whether zero stage 1 and 2 + deliver the same numerical results despite different communication + pattern + + we use these prefixes to differentiate the zero stage + oss: partition optimizer states + pg: partition gradients and optimizer states + + """ + local_rank = torch.distributed.get_rank() + seed_all(2001) + + # create model + zero1_model = TestModel().cuda() + zero2_model = copy.deepcopy(zero1_model) + + # create optimizer + zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) + zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) + zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer, + overlap_communication=True, + initial_scale=128, + verbose=True) + zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, + overlap_communication=True, + partition_grad=True, + initial_scale=128) + # create data + seed_all(2001 + local_rank) + input_data = torch.randn(32, 128).cuda() + + zero1_output = zero1_model(input_data) + zero2_output = zero2_model(input_data) + assert torch.equal(zero1_output, zero2_output) + + # zero-dp backward + zero1_optimizer.backward(zero1_output.mean().float()) + zero2_optimizer.backward(zero2_output.mean().float()) + + for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): + if z2p.grad is not None: + # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad))) + assert torch.equal(z1p.grad, z2p.grad) + + zero1_optimizer.sync_grad() + zero2_optimizer.sync_grad() + + # step + zero1_optimizer.step() + zero2_optimizer.step() + + # check updated param + for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): + assert torch.equal(z1p.data, z2p.data) + + +def exam_zero_1_torch_ddp(): + """ + In this test, two pairs of model and optimizers are created. + 1. zero: use sharded optimizer and fp16 parameters + 2. torch: use torch DDP and fp32 parameters + + We feed these two sets of models with the same input and check if the + differences in model output and updated parameters are within tolerance. + """ + local_rank = torch.distributed.get_rank() + seed_all(1453) + + # create models + zero_model = TestModel() + torch_model = copy.deepcopy(zero_model) + + zero_model = zero_model.cuda().half() + # torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0) + torch_model = torch_model.cuda() + + # for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + # half_close(p.data, z1p.data) + + # create optimizer + zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) + + # we only test stage 1 here + # in `check_sharded_param_consistency.py`, we will test whether + # level 1 and 2 will produce exactly the same results + zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, + overlap_communication=True, + initial_scale=1, + reduce_bucket_size=262144) + + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + + seed_all(1453 + local_rank) + # create + input_data = torch.rand(32, 128).cuda() + + # zero-dp forward + zero_output = zero_model(input_data.half()) + + # torch-ddp forward + torch_output = torch_model(input_data) + half_close(zero_output, torch_output, loose=True) + + # zero-dp backward + zero_optimizer.backward(zero_output.mean().float()) + + # torch-ddp backward + torch_output.mean().backward() + + # check grad + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + half_close(p.grad, z1p.grad, loose=True) + + # zero-dp step + zero_optimizer.sync_grad() + zero_optimizer.step() + + # torch ddp step + torch_optimizer.step() + + # check updated param + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + # print(n, torch.max(torch.abs(p.data - z1p.data))) + half_close(p.data, z1p.data, loose=True) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + + exam_zero_1_torch_ddp() + exam_zero_1_2() + + +@pytest.mark.dist +def test_zero_1_2(): + world_size = 2 + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_zero_1_2() diff --git a/tests/test_zero/test_found_inf.py b/tests/test_zero/test_found_inf.py new file mode 100644 index 0000000000000000000000000000000000000000..695446dd9e996c51ae378facec9e30476c889b14 --- /dev/null +++ b/tests/test_zero/test_found_inf.py @@ -0,0 +1,72 @@ +from functools import partial + +import colossalai +from colossalai.utils.cuda import get_current_device +import pytest +import torch +import torch.multiprocessing as mp +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import BucketTensorShardStrategy +from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.sharded_optim import ShardedOptimizerV2 +from colossalai.zero.sharded_optim._utils import has_inf_or_nan +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_zero.test_sharded_optim_v2 import _run_step + +from common import CONFIG + + +@parameterize("cpu_offload", [True, False]) +@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) +@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) +def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio): + test_models = ['repeated_computed_layers'] + shard_strategy = shard_strategy_class() + + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() + + with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = model_builder(checkpoint=True) + zero_model = ShardedModelV2( + zero_model, + shard_strategy, + tensor_placement_policy='cpu' if cpu_offload else 'cuda', + reuse_fp16_shard=True, + ) + + sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3) + sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio) + + for i, (data, label) in enumerate(train_dataloader): + if i > 1: + break + assert zero_model.overflow_counter == 0 + data, label = data.cuda(), label.cuda() + _run_step(zero_model, sharded_optim, data, label, criterion, False) + for param in zero_model.parameters(): + assert not has_inf_or_nan(param.colo_attr.data_payload) + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + _run_test_found_inf() + + +# use_cpuadam = True can be used with cpu_offload = False +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_found_inf(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_found_inf(world_size=2) diff --git a/tests/test_zero/test_init_context.py b/tests/test_zero/test_init_context.py new file mode 100644 index 0000000000000000000000000000000000000000..d9c2e2f6ca5236f5f34e676011ff782ea19e1f3c --- /dev/null +++ b/tests/test_zero/test_init_context.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from common import CONFIG + +import colossalai +from colossalai.gemini.memory_tracer.utils import colo_model_mem_usage +from colossalai.logging import get_dist_logger +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.memory import colo_device_memory_used +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from tests.components_to_test.registry import non_distributed_component_funcs + + +@parameterize("init_device_type", ['cpu', 'cuda']) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_model_test(init_device_type, shard_strategy_class): + logger = get_dist_logger("test_zero_init") + + for get_components_func in non_distributed_component_funcs: + model_builder, _, _, _, _ = get_components_func() + if init_device_type == 'cuda': + init_device = get_current_device() + elif init_device_type == 'cpu': + init_device = torch.device("cpu") + else: + continue + + model_numel_tensor = torch.zeros(1, dtype=torch.int) + with ZeroInitContext(target_device=init_device, + shard_strategy=shard_strategy_class(), + shard_param=True, + model_numel_tensor=model_numel_tensor): + model = model_builder(checkpoint=True) + + for param in model.parameters(): + assert hasattr(param, 'colo_attr') + assert param.colo_attr.sharded_data_tensor.dtype == torch.half + assert param.colo_attr.sharded_data_tensor.is_sharded + assert param.colo_attr.data_payload.device.type == init_device.type, \ + f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}' + + cuda_mem_use, _ = colo_model_mem_usage(model) + model_data_cuda_mem_MB = cuda_mem_use / 1e6 + logger.info(f"Existing ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0]) + sys_cuda_mem_MB = colo_device_memory_used(get_current_device()) / 1e6 + logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0]) + logger.info(f"Model Number Parameter {model_numel_tensor.numpy()[0]/1e6} M", ranks=[0]) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_model_test() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 4]) +@rerun_if_address_is_in_use() +def test_zero_init_context(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_zero_init_context(4) diff --git a/tests/test_zero/test_shard_model_v2.py b/tests/test_zero/test_shard_model_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..95a9dee38acf9254b772e9c22b99efae5f3fa776 --- /dev/null +++ b/tests/test_zero/test_shard_model_v2.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from common import CONFIG, check_grads_padding, run_fwd_bwd +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import BucketTensorShardStrategy +from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.sharded_model.utils import col_model_deepcopy +from tests.components_to_test.registry import non_distributed_component_funcs + + +@parameterize("enable_autocast", [True]) +@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) +def run_model_test(enable_autocast, shard_strategy_class): + test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model'] + shard_strategy = shard_strategy_class() + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, _, criterion = get_components_func() + + with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = model_builder(checkpoint=True) + zero_model = ShardedModelV2(zero_model, shard_strategy) + + model = model_builder(checkpoint=True).half() + col_model_deepcopy(zero_model, model) + model = model.cuda() + + model = DDP(model, device_ids=[torch.cuda.current_device()]) + + for i, (data, label) in enumerate(train_dataloader): + if i > 5: + break + + data, label = cast_tensor_to_fp16(data).cuda(), label.cuda() + run_fwd_bwd(model, data, label, criterion, enable_autocast) + run_fwd_bwd(zero_model, data, label, criterion, enable_autocast) + + check_grads_padding(model, zero_model, loose=True) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_model_test() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_shard_model_v2(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_shard_model_v2(world_size=2) diff --git a/tests/test_zero/test_shard_param.py b/tests/test_zero/test_shard_param.py new file mode 100644 index 0000000000000000000000000000000000000000..8db2b7e796045210c720122384d3f74d46289e64 --- /dev/null +++ b/tests/test_zero/test_shard_param.py @@ -0,0 +1,95 @@ +from copy import deepcopy +from functools import partial + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) +from colossalai.zero.sharded_param import ShardedTensor +from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 +from tests.test_zero.common import CONFIG, allclose +from colossalai.gemini.stateful_tensor import StatefulTensor + + +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_shard_tensor_with_strategy(shard_strategy_class, world_size): + t = ShardedTensor(tensor=torch.randn(world_size * 2, 3)) + assert list(t.origin_shape) == [world_size * 2, 3] + assert list(t.shape) == [world_size * 2, 3] + + shard_strategy = shard_strategy_class() + + # test shard strategy + shard_strategy.shard([t]) + assert list(t.shape) == [6], f"{list(t.shape)} vs 6" + shard_strategy.gather([t]) + assert list(t.shape) == [world_size * 2, 3], f"{list(t.shape)} vs {[world_size * 2, 3]}" + + +def _run_shard_tensor(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_shard_tensor_with_strategy(world_size=world_size) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_shard_tensor(world_size): + run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +def _run_shard_param_v2(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + param = torch.nn.Parameter(torch.randn(2, 3)) + param_ref = deepcopy(param) + sparam = ShardedParamV2(param=param) + + allclose(sparam.data_payload, param_ref.data) + + # Test get memory usage + sparam.saved_grad = StatefulTensor(torch.randn(2, 3)) + cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() + assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}" + + sparam.set_data_none() + assert (param.data.numel() == 0) + cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() + # 4 is size of dummy tensor of param.data + assert cpu_mem_use == 2 * 3 * 4 * 2 + + sparam.saved_grad = StatefulTensor(torch.randn(2, 3)) + sparam.set_data_none() + cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() + assert cpu_mem_use == 2 * 3 * 4 * 2 + assert cuda_mem_use == 0 + + # append a grad to torch param + param.data = sparam.data_payload + param.grad = torch.randn(2, 3) + cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() + assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}" + assert cuda_mem_use == 0 + + # reuse torch grad for sparam + sparam.saved_grad = StatefulTensor(param.grad) + cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() + assert cpu_mem_use == 2 * 3 * 4 * 2 + assert cuda_mem_use == 0 + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_shard_param_v2(world_size): + run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + # test_shard_tensor(2) + test_shard_param_v2(2) diff --git a/tests/test_zero/test_sharded_optim_state_dict.py b/tests/test_zero/test_sharded_optim_state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c42930b2813eff5987a72b48f02b985270dd6f --- /dev/null +++ b/tests/test_zero/test_sharded_optim_state_dict.py @@ -0,0 +1,93 @@ +import pytest +import colossalai +import torch +import torch.multiprocessing as mp +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils.cuda import get_current_device +from colossalai.utils import free_port +from functools import partial +from tests.test_tensor.common_utils import set_seed +from tests.components_to_test.registry import non_distributed_component_funcs +from colossalai.testing import parameterize +from colossalai.nn.optimizer import HybridAdam +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import TensorShardStrategy +from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.sharded_optim import ShardedOptimizerV2 +from colossalai.tensor import ProcessGroup + + +def init_zero(model_builder, placement_policy): + device = get_current_device() if placement_policy == 'cuda' else torch.device('cpu') + shard_strategy = TensorShardStrategy() + with ZeroInitContext(target_device=device, shard_strategy=shard_strategy, shard_param=True): + model = model_builder() + model = ShardedModelV2( + model, + shard_strategy, + tensor_placement_policy=placement_policy, + reuse_fp16_shard=True, + ) + optim = HybridAdam(model.parameters(), lr=1e-3) + optim = ShardedOptimizerV2(model, optim, initial_scale=32) + return model, optim + + +def run_step(model, optim, criterion, data, label): + optim.zero_grad() + logits = model(data) + loss = criterion(logits, label) + optim.backward(loss) + optim.step() + + +def check_state_dict_eq(state_dict, other): + for p, state in state_dict['state'].items(): + other_state = other['state'][p] + for k, v in state.items(): + if isinstance(v, torch.Tensor): + assert torch.allclose(v, other_state[k], atol=1e-3), f'{v} vs {other_state[k]}' + else: + assert v == other_state[k] + + +@parameterize('placement_policy', ['cuda', 'cpu']) +def run_nested_model(placement_policy): + get_components_func = non_distributed_component_funcs.get_callable('simple_net') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + set_seed(42) + model, optim = init_zero(model_builder, placement_policy) + set_seed(42) + model_copy, optim_copy = init_zero(model_builder, placement_policy) + + model.train() + model_copy.train() + pg = ProcessGroup() + set_seed(pg.dp_local_rank()) + data_iter = iter(train_dataloader) + + data, label = map(lambda x: x.cuda(), next(data_iter)) + run_step(model, optim, criterion, data, label) + optim_copy.load_state_dict(optim.state_dict()) + check_state_dict_eq(optim.state_dict(), optim_copy.state_dict()) + + data, label = map(lambda x: x.cuda(), next(data_iter)) + run_step(model_copy, optim_copy, criterion, data, label) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_nested_model() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_sharded_optim_state_dist(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_sharded_optim_state_dist(2) diff --git a/tests/test_zero/test_sharded_optim_v2.py b/tests/test_zero/test_sharded_optim_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..221915167925b2fc18ad5278bda4194fc88371dc --- /dev/null +++ b/tests/test_zero/test_sharded_optim_v2.py @@ -0,0 +1,115 @@ +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from common import CONFIG, check_sharded_model_params +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.nn.optimizer import CPUAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.sharded_model.utils import col_model_deepcopy +from colossalai.zero.sharded_optim import ShardedOptimizerV2 +from colossalai.zero.sharded_optim._utils import has_inf_or_nan +from tests.components_to_test.registry import non_distributed_component_funcs + + +def _run_step(model, optimizer, data, label, criterion, enable_autocast=False): + model.train() + optimizer.zero_grad() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + + loss = loss.float() + if isinstance(model, ShardedModelV2): + optimizer.backward(loss) + else: + loss.backward() + optimizer.step() + + +@parameterize("cpu_offload", [True, False]) +@parameterize("use_cpuadam", [True, False]) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) +def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio): + test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model'] + shard_strategy = shard_strategy_class() + + if use_cpuadam and cpu_offload is False: + return + if gpu_margin_mem_ratio > 0.0 and not (cpu_offload and use_cpuadam): + return + + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() + + with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = model_builder(checkpoint=True) + zero_model = ShardedModelV2( + zero_model, + shard_strategy, + tensor_placement_policy='cpu' if cpu_offload else 'cuda', + reuse_fp16_shard=use_cpuadam, + ) + + model = model_builder(checkpoint=True).half() + col_model_deepcopy(zero_model, model) + model = model.cuda().float() + + if use_cpuadam: + optimizer_class = CPUAdam + optim = optimizer_class(model.parameters(), lr=1e-3) + sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) + sharded_optim = ShardedOptimizerV2(zero_model, + sharded_optim, + initial_scale=2**5, + gpu_margin_mem_ratio=gpu_margin_mem_ratio) + + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False) + apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) + if dist.get_world_size() > 1: + apex_model = DDP(apex_model, device_ids=[torch.cuda.current_device()]) + + for i, (data, label) in enumerate(train_dataloader): + if i > 5: + break + data, label = data.cuda(), label.cuda() + _run_step(apex_model, apex_optimizer, data, label, criterion, False) + _run_step(zero_model, sharded_optim, data, label, criterion, False) + check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam) + for param in model.parameters(): + assert not has_inf_or_nan(param) + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + _run_test_sharded_optim_v2() + + +# use_cpuadam = True can be used with cpu_offload = False +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_sharded_optim_v2(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_sharded_optim_v2(world_size=2) diff --git a/tests/test_zero/test_sharded_optim_with_sync_bn.py b/tests/test_zero/test_sharded_optim_with_sync_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..ea5b315188a313e14551378b944aa9ac547a4da9 --- /dev/null +++ b/tests/test_zero/test_sharded_optim_with_sync_bn.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from functools import partial + +import colossalai +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import TensorShardStrategy +from torchvision.models import resnet50 + + +def run_dist(rank, world_size, port): + # this test only runs on resnet18 + # as this model has sync batch normalization + # need to configure cudnn deterministic so that + # randomness of convolution layers will be disabled + zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy())) + colossalai.launch(config=dict(zero=zero_config, cudnn_determinstic=True, cudnn_benchmark=False), + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') + + with ZeroInitContext(target_device=torch.cuda.current_device(), + shard_strategy=gpc.config.zero.model_config.shard_strategy, + shard_param=True): + model = resnet50() + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + criterion = torch.nn.CrossEntropyLoss() + + engine, *args = colossalai.initialize(model, optimizer, criterion) + + # train for dummy iterations + engine.train() + for _ in range(2): + data = torch.rand(4, 3, 128, 128).cuda().half() + label = torch.randint(0, 10, size=(4,)).cuda() + engine.zero_grad() + out = engine(data) + loss = engine.criterion(out, label) + engine.backward(loss) + engine.step() + + # test + # need to make sure the batch norm stats are synchronized + # so that given the same input, the model will produce the same + # output on different ranks + engine.eval() + data = torch.rand(4, 3, 128, 128).cuda().half() + dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA)) + + # predict + out = engine(data) + + # test if results are equal + tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)] + tensor_list.insert(rank, out) + dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA)) + + assert torch.all(tensor_list[0] == tensor_list[1]), \ + 'expected the output from different ranks to be the same, but got different values' + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_sharded_optim_with_sync_bn(): + """ + This test is to make sure that buffers are synchronized between ranks + when using ZeRO. An example of module buffer is the running stats of + BatchNormalization layer, i.e. mean and var. + + If the buffers are not synchronized, the model will produce different + output even though the input and parameters are the same. This is not + wanted if we are doing predictions. + + """ + world_size = 2 + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_sharded_optim_with_sync_bn() diff --git a/tests/test_zero/test_state_dict.py b/tests/test_zero/test_state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac9b151e4d68d9ed136db166d3c31a165297ddd --- /dev/null +++ b/tests/test_zero/test_state_dict.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from copy import deepcopy +from functools import partial + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) +from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.sharded_model.utils import col_model_deepcopy +from tests.components_to_test.registry import non_distributed_component_funcs + +from common import CONFIG + + +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_zero_state_dict(shard_strategy_class): + test_models = ['repeated_computed_layers', 'resnet18'] + shard_strategy = shard_strategy_class() + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func() + + with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = model_builder(checkpoint=True) + zero_model = ShardedModelV2(zero_model, shard_strategy) + + model = model_builder(checkpoint=True).half() + col_model_deepcopy(zero_model, model) + model = model.cuda() + + zero_state_dict = zero_model.state_dict() + for key, val in model.state_dict().items(): + assert torch.equal(val, zero_state_dict[key].to(val.device)) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_zero_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_zero_state_dict(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_zero_state_dict(2) diff --git a/tests/test_zero/test_tensor_utils.py b/tests/test_zero/test_tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..81855ff5e10a0e91c3e9b8b90c1527bd3b488c8c --- /dev/null +++ b/tests/test_zero/test_tensor_utils.py @@ -0,0 +1,96 @@ +import pytest + +import colossalai +from colossalai.utils.cuda import get_current_device +from colossalai.gemini.tensor_utils import (colo_tensor_mem_usage, colo_model_data_tensor_move, + colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, + colo_model_tensor_clone) +from colossalai.gemini.stateful_tensor import StatefulTensor +from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use + +import torch + +from functools import partial +import torch.multiprocessing as mp + + +def _run_colo_tensor_mem_usage(): + for i in range(1): + if i == 1: + t1 = StatefulTensor(torch.randn(2, 2)) + t2 = StatefulTensor(torch.randn(4, 4)) + c1, g1 = colo_tensor_mem_usage(t1) + c2, g2 = colo_tensor_mem_usage(t2) + assert c1 * 4 == c2 + assert g1 * 4 == g2 + else: + t1 = torch.randn(2, 2) + t2 = torch.randn(4, 4) + c1, g1 = colo_tensor_mem_usage(t1) + c2, g2 = colo_tensor_mem_usage(t2) + assert c1 * 4 == c2 + assert g1 * 4 == g2 + + +def _run_colo_model_data_tensor_move_inline(): + for t in [StatefulTensor(torch.randn(2, 3)), torch.randn(2, 3)]: + colo_model_data_tensor_move_inline(t, get_current_device()) + assert t.device == get_current_device() + + +def _run_colo_model_data_tensor_move(): + for t in [(StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).to(get_current_device()))), + (torch.ones(2, 3), torch.zeros(2, 3).to(get_current_device()))]: + cpu_t, cuda_t = t + colo_model_data_tensor_move(cpu_t, cuda_t) + assert cuda_t.device == get_current_device() + + +def _run_colo_model_data_move_to_cpu(): + for t in [StatefulTensor(torch.randn(2, 2)), torch.randn(4, 4)]: + colo_model_data_move_to_cpu(t) + assert t.device == torch.device("cpu") + + +def _run_colo_model_tensor_clone(): + for t in [ + StatefulTensor(torch.randn(2, 2).cuda(torch.cuda.current_device())), + torch.randn(4, 4).cuda(torch.cuda.current_device()) + ]: + if issubclass(type(t), StatefulTensor): + assert t.payload.device == get_current_device() + else: + assert t.device == get_current_device() + p = colo_model_tensor_clone(t, get_current_device()) + assert p.device == get_current_device() + for i in range(2): + for j in range(2): + if issubclass(type(t), StatefulTensor): + assert t.payload.device == p.device + assert t.payload[i][j] == p[i][j] + else: + assert t.device == p.device + assert t[i][j] == p[i][j] + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + _run_colo_tensor_mem_usage() + _run_colo_model_data_tensor_move_inline() + _run_colo_model_data_tensor_move() + _run_colo_model_data_move_to_cpu() + _run_colo_model_tensor_clone() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2, 4]) +@rerun_if_address_is_in_use() +def test_zero_tensor_utils(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_zero_tensor_utils(world_size=2) diff --git a/tests/test_zero/test_zero_engine.py b/tests/test_zero/test_zero_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..80ded65d634c54acd27b11c64dd060acdd670454 --- /dev/null +++ b/tests/test_zero/test_zero_engine.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from functools import partial + +import colossalai +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from colossalai.core import global_context as gpc +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.sharded_model.utils import col_model_deepcopy +from colossalai.zero.sharded_optim._utils import has_inf_or_nan +from tests.components_to_test.registry import non_distributed_component_funcs +from torch.nn.parallel import DistributedDataParallel as DDP + +from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params) + + +def run_dist(rank, world_size, port, parallel_config): + colossalai.launch(config=parallel_config, + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') + + test_models = ['repeated_computed_layers', 'resnet18', 'bert'] + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() + with ZeroInitContext(target_device=torch.cuda.current_device(), + shard_strategy=gpc.config.zero.model_config.shard_strategy, + shard_param=True): + colo_model = model_builder(checkpoint=True) + + colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3) + engine, train_dataloader, _, _ = colossalai.initialize(colo_model, + optimizer=colo_optimizer, + criterion=criterion, + train_dataloader=train_dataloader) + torch_model = model_builder(checkpoint=True).half() + col_model_deepcopy(engine.model, torch_model) + torch_model = torch_model.cuda().float() + + engine.train() + torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3) + + if dist.get_world_size() > 1: + torch_model = DDP(torch_model, device_ids=[torch.cuda.current_device()]) + + i = 0 + for data, label in train_dataloader: + if i > 4: + break + + data, label = data.cuda(), label.cuda() + + engine.zero_grad() + torch_optimizer.zero_grad() + + if criterion: + output = engine(data) + loss = engine.criterion(output, label) + + torch_output = torch_model(data) + torch_loss = engine.criterion(torch_output, label) + else: + loss = engine(data, label) + torch_loss = torch_model(data, label) + + engine.backward(loss) + engine.step() + + torch_loss.backward() + + for param in torch_model.parameters(): + if param.grad is not None: + assert not has_inf_or_nan(param.grad) + + torch_optimizer.step() + i += 1 + + if parallel_config == MP_PARALLEL_CONFIG: + check_params(torch_model, colo_model, loose=True) + elif parallel_config == ZERO_PARALLEL_CONFIG: + check_sharded_model_params(torch_model, colo_model, loose=True) + + +# FIXME: enable this test in next PR +@pytest.mark.skip +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2, 4]) +@rerun_if_address_is_in_use() +def test_mp_engine(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG) + mp.spawn(run_func, nprocs=world_size) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_zero_engine(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_zero_engine(world_size=4) diff --git a/version.txt b/version.txt new file mode 100644 index 0000000000000000000000000000000000000000..7ac4e5e38f1e39e79557565a854f377c71d12766 --- /dev/null +++ b/version.txt @@ -0,0 +1 @@ +0.1.13